Example #1
def deserialize_object(identifier: Any,
                       module_objects: Dict,
                       printable_module_name: str = 'object') -> object:
    """Deserialize object using name.

  Args:
    identifier: a string or function.
    module_objects: modules global objects.
    printable_module_name: name of the module,

  Returns:
    deserialized class.

  Raises:
    ValueError: if identifier does not exist or not supported.
  """
    if identifier is None:
        return None

    if isinstance(identifier, str):
        obj = module_objects.get(identifier)
        if obj is None:
            raise ValueError('Unknown ' + printable_module_name + ':' +
                             identifier)
        return obj
    elif tf_inspect.isfunction(identifier):
        return identifier
    else:
        raise ValueError(
            f'Could not interpret serialized {printable_module_name}:{identifier}'
        )
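The helper above resolves a string through the supplied registry and passes functions through untouched. Below is a minimal, hedged sketch of the same lookup using the standard inspect module in place of tf_inspect; the relu function and my_activations registry are made up for illustration.

import inspect

def relu(x):
    return max(x, 0.0)

my_activations = {'relu': relu}  # hypothetical registry of named objects

def lookup(identifier, module_objects, printable_module_name='object'):
    # Same dispatch as above: None passes through, strings are looked up,
    # functions are returned as-is.
    if identifier is None:
        return None
    if isinstance(identifier, str):
        obj = module_objects.get(identifier)
        if obj is None:
            raise ValueError('Unknown %s: %s' % (printable_module_name, identifier))
        return obj
    if inspect.isfunction(identifier):
        return identifier
    raise ValueError('Could not interpret serialized %s: %r'
                     % (printable_module_name, identifier))

assert lookup('relu', my_activations, 'activation') is relu
assert lookup(relu, my_activations, 'activation') is relu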
Example #2
def entity_to_graph(o, conversion_map, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

  The function will also recursively compile all the entities that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python entity.
    conversion_map: A ConversionMap object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with an interface equivalent to `o`,
            but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    ValueError: if the entity type is not supported.
  """
    if tf_inspect.isclass(o):
        node, name, ns = class_to_graph(o, conversion_map)
    elif tf_inspect.isfunction(o):
        node, name, ns = function_to_graph(o, conversion_map, arg_values,
                                           arg_types)
    elif tf_inspect.ismethod(o):
        node, name, ns = function_to_graph(o, conversion_map, arg_values,
                                           arg_types)
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    conversion_map.add_to_cache(o, node)
    if conversion_map.recursive:
        while True:
            candidate = None
            for obj in conversion_map.name_map.keys():
                if obj not in conversion_map.dependency_cache:
                    candidate = obj
                    break
            if candidate is None:
                break
            if (hasattr(candidate, 'im_class')
                    and getattr(candidate, 'im_class')
                    not in conversion_map.partial_types):
                # Class members are converted with their objects, unless they're
                # only converted partially.
                continue
            entity_to_graph(candidate, conversion_map, {}, {})

    return node, name, ns
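The recursive branch above repeatedly scans name_map for entries missing from dependency_cache and converts them one at a time. Here is a hedged, self-contained sketch of that worklist-plus-cache pattern; convert_all, discover and convert_one are illustrative names, not AutoGraph APIs.

def convert_all(entry, discover, convert_one):
    """Convert `entry` and everything it references, exactly once each."""
    dependency_cache = {}
    pending = [entry]
    while pending:
        obj = pending.pop()
        if obj in dependency_cache:
            continue  # the cache is what prevents duplicate conversion
        dependency_cache[obj] = convert_one(obj)
        pending.extend(discover(obj))
    return dependency_cache

# Toy call graph: f calls g and h, g calls h.
calls = {'f': ['g', 'h'], 'g': ['h'], 'h': []}
cache = convert_all('f',
                    discover=lambda name: calls[name],
                    convert_one=lambda name: name.upper())
print(sorted(cache.items()))  # [('f', 'F'), ('g', 'G'), ('h', 'H')]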
Example #3
def class_to_graph(c, conversion_map):
  """Specialization of `entity_to_graph` for classes."""
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_namespace = None
  for _, m in members:
    node, _ = function_to_graph(
        m,
        conversion_map=conversion_map,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    # TODO(mdan): Do not assume all members have the same view of globals.
    if class_namespace is None:
      class_namespace = inspect_utils.getnamespace(m)
    converted_members[m] = node
  namer = conversion_map.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)
  node = gast.ClassDef(
      class_name,
      bases=[],
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])

  return node, class_name
Example #4
    def _generic_test(self,
                      f_raw,
                      examples,
                      input_signature=None,
                      skip_modes=None):
        """Test a function `f_raw` against all tests `examples`.

    Args:
      f_raw: a callable.
      examples: A list of `Example` named tuples.
      input_signature: Input signature to tf.function.
      skip_modes: A list of `RunMode` enums to entirely skip testing in the
        specified `RunMode`s. This is necessary when things fail in a certain
        `RunMode` even before executing the function (e.g. during saving or
        loading in `RunMode.SAVED` mode).
    """
        f_tf = None
        if not skip_modes:
            skip_modes = []

        if tf_inspect.isfunction(f_raw):
            self.recordProperty('f', tf_inspect.getsource(f_raw))
        else:
            self.recordProperty('f', tf_inspect.getdoc(f_raw))

        for arg, out, failure, bugs in examples:
            del out
            self.recordProperty('Input "{}"'.format(arg), {
                'not-working': failure,
                'bugs': bugs
            })

        # Run the function without tf.function
        if RunMode.RAW not in skip_modes:
            self._run_and_check(f_raw, RunMode.RAW, examples)

        # TF Function
        if RunMode.FUNCTION not in skip_modes:
            f_tf = tf.function(f_raw, input_signature=input_signature)
            self._run_and_check(f_tf, RunMode.FUNCTION, examples)

        # XLA Function
        if RunMode.XLA not in skip_modes:
            f_xla = tf.function(f_raw,
                                input_signature=input_signature,
                                experimental_compile=True)
            self._run_and_check(f_xla, RunMode.XLA, examples)

        # Write a saved model and try to run it
        if RunMode.SAVED not in skip_modes:
            module = tf.Module()
            if f_tf:
                module.f = f_tf
            else:
                module.f = tf.function(f_raw, input_signature=input_signature)

            saved_model_dir = tempfile.gettempdir()
            tf.saved_model.save(module, saved_model_dir)
            module_loaded = tf.saved_model.load(saved_model_dir)
            self._run_and_check(module_loaded.f, RunMode.SAVED, examples)
Example #5
    def __new__(mcs, name, bases, clsdict):
        for key, value in clsdict.items():
            if key == "name_scope":
                continue

            elif key.startswith("__") and key != "__call__":
                # Don't patch methods like `__getattr__` or `__del__`.
                continue

            elif tf_inspect.isfunction(value):
                if getattr(value, NO_MODULE_NAME_SCOPE, False):
                    # The function has been annotated to say that no autoscoping should
                    # be applied, so do not patch it.
                    continue
                clsdict[key] = with_name_scope(value)

            elif isinstance(value, property):
                clsdict[key] = property(value.fget if not value.fget else
                                        with_name_scope(value.fget),
                                        value.fset if not value.fset else
                                        with_name_scope(value.fset),
                                        value.fdel if not value.fdel else
                                        with_name_scope(value.fdel),
                                        doc=value.__doc__)

        return super(ModuleMetaclass, mcs).__new__(mcs, name, bases, clsdict)
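The metaclass above patches every plain function (and property accessor) in the class dictionary with with_name_scope before the class object is created. Below is a minimal sketch of the same patching pattern, with a hypothetical traced decorator standing in for with_name_scope.

import functools
import inspect

def traced(fn):
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        print('entering', fn.__name__)
        return fn(*args, **kwargs)
    return wrapper

class PatchingMeta(type):
    def __new__(mcs, name, bases, clsdict):
        for key, value in list(clsdict.items()):
            if key.startswith('__'):
                continue  # leave dunder methods such as __init__ alone
            if inspect.isfunction(value):
                clsdict[key] = traced(value)
        return super().__new__(mcs, name, bases, clsdict)

class Layer(metaclass=PatchingMeta):
    def forward(self, x):
        return x * 2

print(Layer().forward(3))  # prints "entering forward", then 6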
Example #6
def to_graph(o, arg_value_hints=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    o: A Python function or class.
    arg_value_hints: A dict mapping parameter names to objects that can hint
        at the type of those parameters.

  Returns:
    A function with a signature identical to `o`, but which, when executed,
    creates a TF graph with the same functionality as the original entity.
  """
  conversion_map = conversion.ConversionMap()
  _, name = conversion.object_to_graph(o, conversion_map, arg_value_hints)

  module = gast.Module([])
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    module.body.append(parser.parse_str(import_line))
  for dep in conversion_map.dependency_cache.values():
    module.body.append(dep)
  compiled_node = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(o):
    compiled_node.__dict__.update(six.get_function_globals(o))

  compiled_fn = getattr(compiled_node, name)
  return compiled_fn
Example #7
def convert(entity, program_ctx):
    """Converts an entity into an equivalent entity."""

    if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
        if not hasattr(entity, '__code__'):
            raise ValueError(
                'Cannot apply autograph to a function that doesn\'t '
                'expose a __code__ object. If this is a @tf.function,'
                ' try passing f.python_function instead.')
        free_nonglobal_var_names = entity.__code__.co_freevars
    else:
        free_nonglobal_var_names = ()

    for i, name in enumerate(free_nonglobal_var_names):
        if (name == 'ag__'
                and entity.__closure__[i].cell_contents is not ag_internal):
            raise ValueError('entity {} uses the reserved symbol "{}"'.format(
                entity, name))
        # TODO(mdan): In extreme cases, other ag__ symbols may also be clobbered.

    converted_entity_info = _convert_with_cache(entity, program_ctx,
                                                free_nonglobal_var_names)

    return _instantiate(entity, converted_entity_info,
                        free_nonglobal_var_names)
Example #8
def class_to_graph(c, conversion_map):
    """Specialization of `entity_to_graph` for classes."""
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    class_namespace = None
    for _, m in members:
        node, _ = function_to_graph(m,
                                    conversion_map=conversion_map,
                                    arg_values={},
                                    arg_types={'self': (c.__name__, c)},
                                    owner_type=c)
        # TODO(mdan): Do not assume all members have the same view of globals.
        if class_namespace is None:
            class_namespace = inspect_utils.getnamespace(m)
        converted_members[m] = node
    namer = conversion_map.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)
    node = gast.ClassDef(class_name,
                         bases=[],
                         keywords=[],
                         body=list(converted_members.values()),
                         decorator_list=[])

    return node, class_name
Example #9
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorated
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which, when executed,
    creates a TF graph with the same functionality as the original entity.
  """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, do_not_convert, converted_call),
        partial_types=partial_types,
        api_module=tf_inspect.getmodule(to_graph))
    _, name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                         arg_types)

    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        module.body.extend(parser.parse_str(import_line).body)
    for dep in conversion_map.dependency_cache.values():
        module.body.append(dep)
    compiled_node, compiled_src = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        for key, val in inspect_utils.getnamespace(e).items():
            # Avoid overwriting entities that have been transformed.
            if key not in compiled_node.__dict__:
                compiled_node.__dict__[key] = val
    compiled_fn = getattr(compiled_node, name)

    if verbose:
        logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

    return compiled_fn
Example #10
def is_supported_type_for_deprecation(symbol):
    # Exclude Exception subclasses since users should be able to
    # "except" the same type of exception that was "raised" (i.e. we
    # shouldn't wrap it with deprecation alias).
    # Also, exclude subclasses of namedtuples for now.
    return tf_inspect.isfunction(symbol) or (
        tf_inspect.isclass(symbol) and not issubclass(symbol, Exception)
        and tuple not in tf_inspect.getmro(symbol))
Example #11
def is_supported_type_for_deprecation(symbol):
  # Exclude Exception subclasses since users should be able to
  # "except" the same type of exception that was "raised" (i.e. we
  # shouldn't wrap it with deprecation alias).
  # Also, exclude subclasses of namedtuples for now.
  return tf_inspect.isfunction(symbol) or (
      tf_inspect.isclass(symbol) and not issubclass(symbol, Exception)
      and tuple not in tf_inspect.getmro(symbol))
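A hedged demo of the same predicate built on the standard inspect module, showing that plain functions and classes pass while Exception and namedtuple subclasses are excluded.

import collections
import inspect

def is_supported(symbol):
    return inspect.isfunction(symbol) or (
        inspect.isclass(symbol) and not issubclass(symbol, Exception)
        and tuple not in inspect.getmro(symbol))

def f():
    pass

class Plain:
    pass

class MyError(Exception):
    pass

Point = collections.namedtuple('Point', ['x', 'y'])

assert is_supported(f)            # plain function: wrappable
assert is_supported(Plain)        # plain class: wrappable
assert not is_supported(MyError)  # Exception subclass: excluded
assert not is_supported(Point)    # namedtuple subclass: excluded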
Example #12
def parse_entity(entity):
  """Returns the AST of given entity."""
  source = tf_inspect.getsource(entity)

  def fail(comment):
    raise ValueError(
        'Failed to parse source code of {}, which Python reported as:\n{}\n'
        '{}'.format(entity, source, comment))

  # Comments and multiline strings can appear at arbitrary indentation levels,
  # causing textwrap.dedent to not correctly dedent source code.
  # TODO(b/115884650): Automatic handling of comments/multiline strings.
  source = textwrap.dedent(source)

  try:
    return parse_str(source), source

  except IndentationError:
    # The text below lists the causes of this error known to us. There may
    # be more.
    fail('This may be caused by multiline strings or comments not indented at '
         'the same level as the code.')

  except SyntaxError as e:
    if not tf_inspect.isfunction(entity) or entity.__name__ != '<lambda>':
      raise

    # Certain entities, like lambdas, only hold the raw code lines which defined
    # them, which may include surrounding tokens and may be syntactically
    # invalid out of context. For example:
    #
    #     l = (
    #         lambda x: x,)[0]
    #
    # will have the dedented source "lambda x: x,)[0]"
    # Here we make an attempt to strip away the garbage by looking at the
    # information in the syntax error.
    lines = source.split('\n')
    lineno, offset = e.lineno, e.offset  # 1-based

    # Give up if there's nothing we can chip away.
    if len(lines) == lineno and len(lines[-1]) == offset:
      fail('If this is a lambda function, the error may be avoided by creating'
           ' the lambda in a standalone statement.')

    # Drop all lines following the error location
    # TODO(mdan): What's with the pylint errors?
    lines = lines[:lineno]  # pylint:disable=invalid-slice-index
    # Drop all characters following the error location
    lines[-1] = lines[-1][:offset - 1]  # pylint:disable=invalid-slice-index
    new_source = '\n'.join(lines)

    try:
      return parse_str(new_source), new_source
    except SyntaxError as e:
      fail('If this is a lambda function, the error may be avoided by creating'
           ' the lambda in a standalone statement. Tried to strip down the'
           ' source to:\n{}\nBut that did not work.'.format(new_source))
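The SyntaxError branch above trims the dedented source at the reported error location. A hedged, standalone sketch of that trimming idea using the standard ast module follows; the garbled fragment is the lambda example from the comment above, and the exact offset reported can vary between Python versions.

import ast

source = 'lambda x: x,)[0]'  # what getsource can return for a lambda in context

try:
    ast.parse(source)
except SyntaxError as e:
    lines = source.split('\n')
    lines = lines[:e.lineno]               # drop lines after the error
    lines[-1] = lines[-1][:e.offset - 1]   # drop characters after the error
    source = '\n'.join(lines)

ast.parse(source)   # the trimmed fragment now parses
print(source)       # e.g. 'lambda x: x,'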
Example #13
def to_graph(e,
             recursive=True,
             verbose=False,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorated
        function may call.
    verbose: Whether to output the compiled code in the logs.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which, when executed,
    creates a TF graph with the same functionality as the original entity.
  """
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, do_not_convert, converted_call),
      partial_types=partial_types,
      api_module=tf_inspect.getmodule(to_graph))
  _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  module = gast.Module([])
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    module.body.extend(parser.parse_str(import_line).body)
  for dep in conversion_map.dependency_cache.values():
    module.body.append(dep)
  compiled_node, compiled_src = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    for key, val in inspect_utils.getnamespace(e).items():
      # Avoid overwriting entities that have been transformed.
      if key not in compiled_node.__dict__:
        compiled_node.__dict__[key] = val
  compiled_fn = getattr(compiled_node, name)

  if verbose:
    logging.info('Compiled output of %s:\n\n%s\n', e, compiled_src)

  return compiled_fn
Example #14
def islambda(f):
  if not tf_inspect.isfunction(f):
    return False
  # TODO(mdan): Look into checking only the code object.
  if not (hasattr(f, '__name__') and hasattr(f, '__code__')):
    return False
  # Some wrappers can rename the function, but changing the name of the
  # code object is harder.
  return ((f.__name__ == '<lambda>') or (f.__code__.co_name == '<lambda>'))
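A hedged demo of the same lambda check using the standard inspect module; the code-object name survives wrappers that rename the function itself.

import inspect

def islambda(f):
    if not inspect.isfunction(f):
        return False
    # Wrappers may rename the function, but the code object keeps '<lambda>'.
    return f.__name__ == '<lambda>' or f.__code__.co_name == '<lambda>'

square = lambda x: x * x

def cube(x):
    return x ** 3

assert islambda(square)
assert not islambda(cube)
assert not islambda(len)  # builtins are not Python functions at all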
Example #15
def entity_to_graph(o, conversion_map, arg_values, arg_types):
  """Compile a Python entity into equivalent TensorFlow.

  The function will also recursively compile all the entities that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python entity.
    conversion_map: A ConversionMap object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with an interface equivalent to `o`,
            but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    ValueError: if the entity type is not supported.
  """
  if tf_inspect.isclass(o):
    node, name, ns = class_to_graph(o, conversion_map)
  elif tf_inspect.isfunction(o):
    node, name, ns = function_to_graph(o, conversion_map, arg_values, arg_types)
  elif tf_inspect.ismethod(o):
    node, name, ns = function_to_graph(o, conversion_map, arg_values, arg_types)
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  conversion_map.add_to_cache(o, node)
  if conversion_map.recursive:
    while True:
      candidate = None
      for obj in conversion_map.name_map.keys():
        if obj not in conversion_map.dependency_cache:
          candidate = obj
          break
      if candidate is None:
        break
      if (hasattr(candidate, 'im_class') and
          getattr(candidate, 'im_class') not in conversion_map.partial_types):
        # Class members are converted with their objects, unless they're
        # only converted partially.
        continue
      entity_to_graph(candidate, conversion_map, {}, {})

  return node, name, ns
Example #16
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
      parameters.
    arg_types: A dict containing type hints for symbols like function
      parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with an interface equivalent to `o`,
            but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    ValueError: if the entity type is not supported.
  """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        nodes, name, entity_info = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o):
        nodes, name, entity_info = function_to_graph(o, program_ctx,
                                                     arg_values, arg_types)
    elif tf_inspect.ismethod(o):
        nodes, name, entity_info = function_to_graph(o, program_ctx,
                                                     arg_values, arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    elif hasattr(o, '__class__'):
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'python/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    compiler.ast_to_source(nodes))
    if logging.has_verbosity(4):
        for n in nodes:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                        pretty_printer.fmt(n, color=False))

    return nodes, name, entity_info
Example #17
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
    """Turns the serialized form of a Keras object back into an actual object."""
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name)

        if hasattr(cls, 'from_config'):
            arg_spec = tf_inspect.getfullargspec(cls.from_config)
            custom_objects = custom_objects or {}

            if 'custom_objects' in arg_spec.args:
                return cls.from_config(
                    cls_config,
                    custom_objects=dict(
                        list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                        list(custom_objects.items())))
            with CustomObjectScope(custom_objects):
                return cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            custom_objects = custom_objects or {}
            with CustomObjectScope(custom_objects):
                return cls(**cls_config)
    elif isinstance(identifier, six.string_types):
        object_name = identifier
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif object_name in _GLOBAL_CUSTOM_OBJECTS:
            obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            obj = module_objects.get(object_name)
            if obj is None:
                raise ValueError('Unknown ' + printable_module_name + ': ' +
                                 object_name)
        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if tf_inspect.isclass(obj):
            return obj()
        return obj
    elif tf_inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier
    else:
        raise ValueError('Could not interpret serialized %s: %s' %
                         (printable_module_name, identifier))
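A hedged sketch of the config format this helper consumes, without any Keras dependency: a dict of the shape {'class_name': ..., 'config': ...} is mapped back to a registered class via its from_config classmethod. Dense and registry here are stand-ins for illustration, not the real Keras objects.

class Dense:
    def __init__(self, units):
        self.units = units

    @classmethod
    def from_config(cls, config):
        return cls(**config)

registry = {'Dense': Dense}  # stand-in for module_objects / custom_objects

def deserialize(identifier, module_objects):
    if identifier is None:
        return None
    if isinstance(identifier, dict):             # a config dictionary
        cls = module_objects[identifier['class_name']]
        return cls.from_config(identifier['config'])
    if isinstance(identifier, str):              # referenced by name
        obj = module_objects[identifier]
        return obj() if isinstance(obj, type) else obj
    return identifier                            # already deserialized

layer = deserialize({'class_name': 'Dense', 'config': {'units': 8}}, registry)
print(layer.units)  # 8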
Example #18
def _get_func_name(func):
  _, func = tf_decorator.unwrap(func)
  if callable(func):
    if tf_inspect.isfunction(func):
      return func.__name__
    elif tf_inspect.ismethod(func):
      return "%s.%s" % (func.__self__.__name__, func.__name__)
    else:  # Probably a class instance with __call__
      return type(func)
  else:
    raise ValueError("Argument must be callable")
Example #19
def _get_func_name(func):
  _, func = tf_decorator.unwrap(func)
  if callable(func):
    if tf_inspect.isfunction(func):
      return func.__name__
    elif tf_inspect.ismethod(func):
      return "%s.%s" % (func.__self__.__name__, func.__name__)
    else:  # Probably a class instance with __call__
      return type(func)
  else:
    raise ValueError("Argument must be callable")
Example #20
    def conversion_visitor(unused_path, unused_parent, children):
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        if not tf_inspect.isfunction(attr):
          continue
        names_v1 = tf_export.get_v1_names(attr)
        arg_names_v1 = get_args(attr)

        for name in names_v1:
          tf_name = "tf.%s" % name
          if tf_name in function_warnings or tf_name in function_transformers:
            continue  # These require manual change
          if tf_name in v1_name_exceptions:
            continue
          # Assert that arg names after converting to v2 are present in
          # v2 function.
          # 1. First, create an input of the form:
          #    tf.foo(arg1=val1, arg2=val2, ...)
          args = ",".join(
              ["%s=%d" % (from_name, from_index)
               for from_index, from_name in enumerate(arg_names_v1)])
          text_input = "%s(%s)" % (tf_name, args)
          # 2. Convert the input to V2.
          _, _, _, text = self._upgrade(text_input)
          new_function_name, new_args = get_func_and_args_from_str(text)
          if new_function_name == "tf.compat.v1.%s" % name:
            if tf_name in keyword_renames:
              # If we rename arguments, new function must be available in 2.0.
              # We should not be using compat.v1 in this case.
              self.assertFalse(
                  "Function '%s' is not in 2.0 when converting\n%s\nto\n%s" %
                  (new_function_name, text_input, text))
            continue
          # 3. Verify V2 function and arguments.
          args_v2 = get_args(self.v2_symbols[new_function_name])
          args_v2.extend(v2_arg_exceptions)
          for new_arg in new_args:
            self.assertIn(
                new_arg, args_v2,
                "Invalid argument '%s' in 2.0 when converting\n%s\nto\n%s.\n"
                "Supported arguments: %s" % (
                    new_arg, text_input, text, str(args_v2)))
          # 4. Verify that the argument exists in v1 as well.
          if new_function_name in set(["tf.nn.ctc_loss",
                                       "tf.saved_model.save"]):
            continue
          args_v1 = get_args(self.v1_symbols[new_function_name])
          args_v1.extend(v2_arg_exceptions)
          for new_arg in new_args:
            self.assertIn(
                new_arg, args_v1,
                "Invalid argument '%s' in 1.0 when converting\n%s\nto\n%s.\n"
                "Supported arguments: %s" % (
                    new_arg, text_input, text, str(args_v1)))
Example #21
def get_func_code(func):
    """Returns func_code of passed callable."""
    _, func = tf_decorator.unwrap(func)
    if callable(func):
        if tf_inspect.isfunction(func) or tf_inspect.ismethod(func):
            return six.get_function_code(func)
        elif hasattr(func, '__call__'):
            return six.get_function_code(func.__call__)
        else:
            raise ValueError('Unhandled callable, type=%s' % type(func))
    else:
        raise ValueError('Argument must be callable')
Example #22
def get_func_code(func):
  """Returns func_code of passed callable."""
  _, func = tf_decorator.unwrap(func)
  if callable(func):
    if tf_inspect.isfunction(func) or tf_inspect.ismethod(func):
      return six.get_function_code(func)
    elif hasattr(func, '__call__'):
      return six.get_function_code(func.__call__)
    else:
      raise ValueError('Unhandled callable, type=%s' % type(func))
  else:
    raise ValueError('Argument must be callable')
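A hedged demo of where the code object lives for each kind of callable; on Python 3 the __code__ attribute can be read directly instead of going through six.get_function_code.

def f(x):
    return x + 1

class Adder:
    def __call__(self, x):
        return x + 1

print(f.__code__.co_name)                 # f
print(Adder().__call__.__code__.co_name)  # __call__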
Example #23
def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
    """Creates a converted instance and binds it to match original entity."""
    factory = converted_entity_info.get_factory()

    # `factory` is currently bound to the empty module it was loaded from.
    # It must instead be bound to the globals and closure from the original
    # entity.
    if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
        entity_globals = entity.__globals__
        entity_closure = entity.__closure__ or ()
    elif hasattr(entity, '__module__'):
        entity_globals = sys.modules[entity.__module__].__dict__
        entity_closure = ()
    assert len(entity_closure) == len(free_nonglobal_var_names)

    # Fit the original entity's cells to match the order of factory's cells.
    original_names_and_cells = dict(
        zip(free_nonglobal_var_names, entity_closure))
    new_factory_cells = tuple(original_names_and_cells[name]
                              for name in factory.__code__.co_freevars)

    bound_factory = types.FunctionType(code=factory.__code__,
                                       globals=entity_globals,
                                       name=factory.__name__,
                                       argdefs=(),
                                       closure=new_factory_cells)

    # Two other free vars: the internal "ag__" module and the source
    # map. These are wired via the parameters of the factory.
    converted_entity = bound_factory(  # pylint:disable=not-callable
        ag_internal, converted_entity_info.source_map,
        converted_entity_info.get_module())

    if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
        # Attach the default argument to the converted function.
        converted_entity.__defaults__ = entity.__defaults__
        if hasattr(entity, '__kwdefaults__'):
            converted_entity.__kwdefaults__ = entity.__kwdefaults__

    return converted_entity
Example #24
def get_func_name(func):
  """Returns name of passed callable."""
  _, func = tf_decorator.unwrap(func)
  if callable(func):
    if tf_inspect.isfunction(func):
      return func.__name__
    elif tf_inspect.ismethod(func):
      return '%s.%s' % (six.get_method_self(func).__class__.__name__,
                        six.get_method_function(func).__name__)
    else:  # Probably a class instance with __call__
      return str(type(func))
  else:
    raise ValueError('Argument must be callable')
Example #25
def getfutureimports(entity):
  """Detects what future imports are necessary to safely execute entity source.

  Args:
    entity: Any object

  Returns:
    A tuple of future strings
  """
  if not (tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity)):
    return tuple()
  return tuple(sorted(name for name, value in entity.__globals__.items()
                      if getattr(value, '__module__', None) == '__future__'))
Example #26
def get_func_name(func):
    """Returns name of passed callable."""
    _, func = tf_decorator.unwrap(func)
    if callable(func):
        if tf_inspect.isfunction(func):
            return func.__name__
        elif tf_inspect.ismethod(func):
            return '%s.%s' % (six.get_method_self(func).__class__.__name__,
                              six.get_method_function(func).__name__)
        else:  # Probably a class instance with __call__
            return type(func)
    else:
        raise ValueError('Argument must be callable')
Example #27
def getfutureimports(entity):
  """Detects what future imports are necessary to safely execute entity source.

  Args:
    entity: Any object

  Returns:
    A tuple of future strings
  """
  if not (tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity)):
    return tuple()
  return tuple(sorted(name for name, value in entity.__globals__.items()
                      if getattr(value, '__module__', None) == '__future__'))
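A hedged, standalone demo of the same detection with the standard inspect module; note that the __future__ import has to sit at the top of the module whose functions are being inspected.

from __future__ import annotations

import inspect

def getfutureimports(entity):
    if not (inspect.isfunction(entity) or inspect.ismethod(entity)):
        return tuple()
    return tuple(sorted(name for name, value in entity.__globals__.items()
                        if getattr(value, '__module__', None) == '__future__'))

def g():
    pass

print(getfutureimports(g))   # ('annotations',)
print(getfutureimports(42))  # ()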
Example #28
def convert_entity_to_ast(o, program_ctx):
  """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with an interface equivalent to `o`,
            but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    ValueError: if the entity type is not supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
  elif tf_inspect.isfunction(o):
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  elif tf_inspect.ismethod(o):
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'python/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(nodes))
  if logging.has_verbosity(4):
    for n in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  return nodes, name, entity_info
Example #29
def class_and_config_for_serialized_keras_object(
        config,
        module_objects=None,
        custom_objects=None,
        printable_module_name='object'):
    """Returns the class name and config for a serialized keras object."""
    if (not isinstance(config, dict) or 'class_name' not in config
            or 'config' not in config):
        raise ValueError('Improper config format: ' + str(config))

    class_name = config['class_name']
    cls = get_registered_object(class_name, custom_objects, module_objects)
    if cls is None:
        raise ValueError('Unknown ' + printable_module_name + ': ' +
                         class_name)

    cls_config = config['config']
    # Check if `cls_config` is a list. If it is, return the class and the
    # associated class configs for recursive deserialization. This case occurs
    # with old versions of the Sequential model (e.g. `keras_version` ==
    # "2.0.6"), which were serialized in a different structure, for example
    # "{'class_name': 'Sequential',
    #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
    if isinstance(cls_config, list):
        return (cls, cls_config)

    deserialized_objects = {}
    for key, item in cls_config.items():
        if isinstance(item, dict) and '__passive_serialization__' in item:
            deserialized_objects[key] = deserialize_keras_object(
                item,
                module_objects=module_objects,
                custom_objects=custom_objects,
                printable_module_name='config_item')
        # TODO(momernick): Should this also have 'module_objects'?
        elif (isinstance(item, six.string_types) and tf_inspect.isfunction(
                get_registered_object(item, custom_objects))):
            # Handle custom functions here. When saving functions, we only save the
            # function's name as a string. If we find a matching string in the custom
            # objects during deserialization, we convert the string back to the
            # original function.
            # Note that a potential issue is that a string field could have a naming
            # conflict with a custom function name, but this should be a rare case.
            # This issue does not occur if a string field has a naming conflict with
            # a custom object, since the config of an object will always be a dict.
            deserialized_objects[key] = get_registered_object(
                item, custom_objects)
    for key, item in deserialized_objects.items():
        cls_config[key] = deserialized_objects[key]

    return (cls, cls_config)
Example #30
def generate_global_index(library_name, index, reference_resolver):
    """Given a dict of full names to python objects, generate an index page.

  The index page generated contains a list of links for all symbols in `index`
  that have their own documentation page.

  Args:
    library_name: The name for the documented library to use in the title.
    index: A dict mapping full names to python objects.
    reference_resolver: An instance of ReferenceResolver.

  Returns:
    A string containing an index page as Markdown.
  """
    symbol_links = []
    for full_name, py_object in six.iteritems(index):
        if (tf_inspect.ismodule(py_object) or tf_inspect.isfunction(py_object)
                or tf_inspect.isclass(py_object)):
            # In Python 3, unbound methods are functions, so eliminate those.
            if tf_inspect.isfunction(py_object):
                if full_name.count('.') == 0:
                    parent_name = ''
                else:
                    parent_name = full_name[:full_name.rfind('.')]
                if parent_name in index and tf_inspect.isclass(
                        index[parent_name]):
                    # Skip methods (=functions with class parents).
                    continue
            symbol_links.append(
                (full_name,
                 reference_resolver.python_link(full_name, full_name, '.')))

    lines = ['# All symbols in %s' % library_name, '']
    for _, link in sorted(symbol_links, key=lambda x: x[0]):
        lines.append('*  %s' % link)

    # TODO(markdaoust): use a _ModulePageInfo -> pretty_docs.build_md_page()
    return '\n'.join(lines)
Example #31
def _instantiate(entity, converted_entity_info, free_nonglobal_var_names):
  """Creates a converted instance and binds it to match original entity."""
  factory = converted_entity_info.get_factory()

  # `factory` is currently bound to the empty module it was loaded from.
  # It must instead be bound to the globals and closure from the original
  # entity.
  if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
    entity_globals = entity.__globals__
    entity_closure = entity.__closure__ or ()
  elif hasattr(entity, '__module__'):
    entity_globals = sys.modules[entity.__module__].__dict__
    entity_closure = ()
  assert len(entity_closure) == len(free_nonglobal_var_names)

  # Fit the original entity's cells to match the order of factory's cells.
  original_names_and_cells = dict(zip(free_nonglobal_var_names, entity_closure))
  new_factory_cells = tuple(
      original_names_and_cells[name] for name in factory.__code__.co_freevars)

  bound_factory = types.FunctionType(
      code=factory.__code__,
      globals=entity_globals,
      name=factory.__name__,
      argdefs=(),
      closure=new_factory_cells)

  # Two other free vars: the internal "ag__" module and the source
  # map. These are wired via the parameters of the factory.
  converted_entity = bound_factory(  # pylint:disable=not-callable
      ag_internal, converted_entity_info.source_map,
      converted_entity_info.get_module())

  if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
    # Attach the default argument to the converted function.
    converted_entity.__defaults__ = entity.__defaults__

  return converted_entity
Example #32
def object_to_graph(o, conversion_map, value_hints):
  """Compile a Python object into equivalent TensorFlow.

  The function will also recursively compile all the objects that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python object.
    conversion_map: A ConversionMap object.
    value_hints: A dict containing value hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name):
        * ast: An AST representing an object with an interface equivalent to `o`,
            but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new object can be found.

  Raises:
    ValueError: if the object is not supported.
  """
  if value_hints is None:
    value_hints = {}

  if tf_inspect.isclass(o):
    node, new_name = class_to_graph(o, conversion_map, value_hints)
  elif tf_inspect.isfunction(o):
    node, new_name = function_to_graph(o, conversion_map, value_hints)
  elif tf_inspect.ismethod(o):
    node, new_name = function_to_graph(o, conversion_map, value_hints)
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  conversion_map.add_to_cache(o, node)
  if conversion_map.recursive:
    for obj in conversion_map.name_map.keys():
      if obj not in conversion_map.dependency_cache:
        if (hasattr(obj, 'im_class') and
            getattr(obj, 'im_class') not in conversion_map.partial_types):
          # Class members are converted with their objects, unless they're
          # only converted partially.
          continue
        object_to_graph(obj, conversion_map, None)

  return node, new_name
Example #33
def to_graph(e,
             recursive=True,
             arg_values=None,
             arg_types=None,
             partial_types=None):
    """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorated
        function may call.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which, when executed,
    creates a TF graph with the same functionality as the original entity.
  """
    conversion_map = conversion.ConversionMap(
        recursive=recursive,
        nocompile_decorators=(convert, graph_ready, convert_inline),
        partial_types=partial_types)
    _, name = conversion.entity_to_graph(e, conversion_map, arg_values,
                                         arg_types)

    module = gast.Module([])
    for import_line in config.COMPILED_IMPORT_STATEMENTS:
        module.body.append(parser.parse_str(import_line))
    for dep in conversion_map.dependency_cache.values():
        module.body.append(dep)
    compiled_node = compiler.ast_to_object(module)

    # The compiled code should see everything the entry function saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    if tf_inspect.isfunction(e):
        compiled_node.__dict__.update(six.get_function_globals(e))

    compiled_fn = getattr(compiled_node, name)
    return compiled_fn
Example #34
def object_to_graph(o, conversion_map, value_hints):
  """Compile a Python object into equivalent TensorFlow.

  The function will also recursively compile all the objects that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python object.
    conversion_map: A ConversionMap object.
    value_hints: A dict containing value hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name):
        * ast: An AST representing an object with an interface equivalent to `o`,
            but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new object can be found.

  Raises:
    ValueError: if the object is not supported.
  """
  if value_hints is None:
    value_hints = {}

  if tf_inspect.isclass(o):
    node, new_name = class_to_graph(o, conversion_map, value_hints)
  elif tf_inspect.isfunction(o):
    node, new_name = function_to_graph(o, conversion_map, value_hints)
  elif tf_inspect.ismethod(o):
    node, new_name = function_to_graph(o, conversion_map, value_hints)
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  conversion_map.add_to_cache(o, node)
  if conversion_map.recursive:
    for obj in conversion_map.name_map.keys():
      if obj not in conversion_map.dependency_cache:
        if (hasattr(obj, 'im_class') and
            getattr(obj, 'im_class') not in conversion_map.partial_types):
          # Class members are converted with their objects, unless they're
          # only converted partially.
          continue
        object_to_graph(obj, conversion_map, None)

  return node, new_name
Example #35
def to_graph(e,
             recursive=True,
             arg_values=None,
             arg_types=None,
             partial_types=None):
  """Compile a Python entity into equivalent TensorFlow code.

  Currently supported entities:
    * functions
    * classes

  Classes are handled by converting all their methods into a new class.

  Args:
    e: A Python entity.
    recursive: Whether to recursively convert any functions that the decorated
        function may call.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.
    partial_types: A set of types (e.g. classes) that will not be converted
        entirely. Calls to member functions for these types will be renamed
        independently.

  Returns:
    A function with a signature identical to `e`, but which, when executed,
    creates a TF graph with the same functionality as the original entity.
  """
  conversion_map = conversion.ConversionMap(
      recursive=recursive,
      nocompile_decorators=(convert, graph_ready, convert_inline),
      partial_types=partial_types)
  _, name = conversion.entity_to_graph(e, conversion_map, arg_values, arg_types)

  module = gast.Module([])
  for import_line in config.COMPILED_IMPORT_STATEMENTS:
    module.body.append(parser.parse_str(import_line))
  for dep in conversion_map.dependency_cache.values():
    module.body.append(dep)
  compiled_node = compiler.ast_to_object(module)

  # The compiled code should see everything the entry function saw.
  # TODO(mdan): This might not work well if the call tree spans modules?
  if tf_inspect.isfunction(e):
    compiled_node.__dict__.update(six.get_function_globals(e))

  compiled_fn = getattr(compiled_node, name)
  return compiled_fn
Example #36
def _is_free_function(py_object, full_name, index):
    """Check if input is a free function (and not a class- or static method)."""
    if not tf_inspect.isfunction(py_object):
        return False

    # Static methods are functions to tf_inspect (in 2.7), so check if the parent
    # is a class. If there is no parent, it's not a function.
    if '.' not in full_name:
        return False

    parent_name = full_name.rsplit('.', 1)[0]
    if tf_inspect.isclass(index[parent_name]):
        return False

    return True
Example #37
def generate_global_index(library_name, index, reference_resolver):
  """Given a dict of full names to python objects, generate an index page.

  The index page generated contains a list of links for all symbols in `index`
  that have their own documentation page.

  Args:
    library_name: The name for the documented library to use in the title.
    index: A dict mapping full names to python objects.
    reference_resolver: An instance of ReferenceResolver.

  Returns:
    A string containing an index page as Markdown.
  """
  symbol_links = []
  for full_name, py_object in six.iteritems(index):
    if (tf_inspect.ismodule(py_object) or tf_inspect.isfunction(py_object) or
        tf_inspect.isclass(py_object)):
      # In Python 3, unbound methods are functions, so eliminate those.
      if tf_inspect.isfunction(py_object):
        if full_name.count('.') == 0:
          parent_name = ''
        else:
          parent_name = full_name[:full_name.rfind('.')]
        if parent_name in index and tf_inspect.isclass(index[parent_name]):
          # Skip methods (=functions with class parents).
          continue
      symbol_links.append((
          full_name, reference_resolver.python_link(full_name, full_name, '.')))

  lines = ['# All symbols in %s' % library_name, '']
  for _, link in sorted(symbol_links, key=lambda x: x[0]):
    lines.append('*  %s' % link)

  # TODO(markdaoust): use a _ModulePageInfo -> pretty_docs.build_md_page()
  return '\n'.join(lines)
Example #38
def _is_free_function(py_object, full_name, index):
  """Check if input is a free function (and not a class- or static method)."""
  if not tf_inspect.isfunction(py_object):
    return False

  # Static methods are functions to tf_inspect (in 2.7), so check if the parent
  # is a class. If there is no parent, it's not a function.
  if '.' not in full_name:
    return False

  parent_name = full_name.rsplit('.', 1)[0]
  if tf_inspect.isclass(index[parent_name]):
    return False

  return True
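A hedged demo of the same parent lookup using the standard inspect module and a toy documentation index; mod, Box and top_level are made up for illustration.

import inspect
import types

mod = types.ModuleType('mod')

def top_level():
    pass

class Box:
    @staticmethod
    def make():
        pass

index = {'mod': mod, 'mod.top_level': top_level,
         'mod.Box': Box, 'mod.Box.make': Box.make}

def is_free_function(py_object, full_name, index):
    if not inspect.isfunction(py_object):
        return False
    if '.' not in full_name:
        return False
    parent_name = full_name.rsplit('.', 1)[0]
    return not inspect.isclass(index[parent_name])

assert is_free_function(top_level, 'mod.top_level', index)    # parent is a module
assert not is_free_function(Box.make, 'mod.Box.make', index)  # parent is a class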
Example #39
def get_func_code(func):
  """Returns func_code of passed callable, or None if not available."""
  _, func = tf_decorator.unwrap(func)
  if callable(func):
    if tf_inspect.isfunction(func) or tf_inspect.ismethod(func):
      return six.get_function_code(func)
    # Since the object is not a function or method, but is a callable, we will
    # try to access the __call__ method as a function. This works with callable
    # classes but fails with functools.partial objects despite their __call__
    # attribute.
    try:
      return six.get_function_code(func.__call__)
    except AttributeError:
      return None
  else:
    raise ValueError('Argument must be callable')
Example #40
def get_func_code(func):
  """Returns func_code of passed callable, or None if not available."""
  _, func = tf_decorator.unwrap(func)
  if callable(func):
    if tf_inspect.isfunction(func) or tf_inspect.ismethod(func):
      return six.get_function_code(func)
    # Since the object is not a function or method, but is a callable, we will
    # try to access the __call__ method as a function. This works with callable
    # classes but fails with functools.partial objects despite their __call__
    # attribute.
    try:
      return six.get_function_code(func.__call__)
    except AttributeError:
      return None
  else:
    raise ValueError('Argument must be callable')
Example #41
def _loop_fn_has_config(loop_fn):
    """Test if `loop_fn` has a `pfor_config` argument."""
    if tf_inspect.isfunction(loop_fn):
        argspec = tf_inspect.getargspec(loop_fn)
        return PFOR_CONFIG_ARG in argspec.args
    elif isinstance(loop_fn, functools.partial):
        fn = loop_fn.func
        argspec = tf_inspect.getargspec(fn)
        return (PFOR_CONFIG_ARG in argspec.args
                and PFOR_CONFIG_ARG not in loop_fn.keywords)
    else:
        loop_class = tf_decorator.unwrap(loop_fn)[1]
        if not hasattr(loop_class, "__call__"):
            raise ValueError("loop_fn object did not have a __call__ method")
        argspec = tf_inspect.getargspec(loop_class.__call__)
        return PFOR_CONFIG_ARG in argspec.args
Example #42
def _loop_fn_has_config(loop_fn):
  """Test if `loop_fn` has a `pfor_config` argument."""
  if tf_inspect.isfunction(loop_fn):
    argspec = tf_inspect.getargspec(loop_fn)
    return PFOR_CONFIG_ARG in argspec.args
  elif isinstance(loop_fn, functools.partial):
    fn = loop_fn.func
    argspec = tf_inspect.getargspec(fn)
    return (PFOR_CONFIG_ARG in argspec.args and
            PFOR_CONFIG_ARG not in loop_fn.keywords)
  else:
    loop_class = tf_decorator.unwrap(loop_fn)[1]
    if not hasattr(loop_class, "__call__"):
      raise ValueError("loop_fn object did not have a __call__ method")
    argspec = tf_inspect.getargspec(loop_class.__call__)
    return PFOR_CONFIG_ARG in argspec.args
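A hedged demo of the same argument-spec check built on the standard inspect module, covering plain functions, functools.partial objects, and callable instances; pfor_config here mirrors PFOR_CONFIG_ARG.

import functools
import inspect

PFOR_CONFIG_ARG = 'pfor_config'

def has_config(loop_fn):
    if inspect.isfunction(loop_fn):
        return PFOR_CONFIG_ARG in inspect.getfullargspec(loop_fn).args
    if isinstance(loop_fn, functools.partial):
        args = inspect.getfullargspec(loop_fn.func).args
        return PFOR_CONFIG_ARG in args and PFOR_CONFIG_ARG not in loop_fn.keywords
    if not hasattr(loop_fn, '__call__'):
        raise ValueError('loop_fn object did not have a __call__ method')
    return PFOR_CONFIG_ARG in inspect.getfullargspec(loop_fn.__call__).args

def body_a(i):
    return i

def body_b(i, pfor_config=None):
    return i

assert not has_config(body_a)
assert has_config(body_b)
assert not has_config(functools.partial(body_b, pfor_config=None))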
Example #43
def class_and_config_for_serialized_keras_object(
        config,
        module_objects=None,
        custom_objects=None,
        printable_module_name='object'):
    """Returns the class name and config for a serialized keras object."""
    if (not isinstance(config, dict) or 'class_name' not in config
            or 'config' not in config):
        raise ValueError('Improper config format: ' + str(config))

    class_name = config['class_name']
    if custom_objects and class_name in custom_objects:
        cls = custom_objects[class_name]
    elif class_name in _GLOBAL_CUSTOM_OBJECTS:
        cls = _GLOBAL_CUSTOM_OBJECTS[class_name]
    else:
        module_objects = module_objects or {}
        cls = module_objects.get(class_name)
        if cls is None:
            raise ValueError('Unknown ' + printable_module_name + ': ' +
                             class_name)

    cls_config = config['config']
    deserialized_objects = {}
    for key, item in cls_config.items():
        if isinstance(item, dict) and '__passive_serialization__' in item:
            deserialized_objects[key] = deserialize_keras_object(
                item,
                module_objects=module_objects,
                custom_objects=custom_objects,
                printable_module_name='config_item')
        elif (isinstance(item, six.string_types) and tf_inspect.isfunction(
                _get_custom_objects_by_name(item, custom_objects))):
            # Handle custom functions here. When saving functions, we only save the
            # function's name as a string. If we find a matching string in the custom
            # objects during deserialization, we convert the string back to the
            # original function.
            # Note that a potential issue is that a string field could have a naming
            # conflict with a custom function name, but this should be a rare case.
            # This issue does not occur if a string field has a naming conflict with
            # a custom object, since the config of an object will always be a dict.
            deserialized_objects[key] = _get_custom_objects_by_name(
                item, custom_objects)
    for key, item in deserialized_objects.items():
        cls_config[key] = deserialized_objects[key]

    return (cls, cls_config)
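
For reference, a hypothetical config of the shape this function expects (illustrative only; the class and custom-function names below are made up):

def my_activation(x):
  return x

config = {
    'class_name': 'Dense',              # looked up in custom_objects, then module_objects
    'config': {
        'units': 32,
        'activation': 'my_activation',  # a string naming a registered custom function
    },
}
custom_objects = {'my_activation': my_activation}

# class_and_config_for_serialized_keras_object(config, module_objects,
#                                              custom_objects) would return the
# resolved class together with cls_config, with 'my_activation' swapped back to
# the function if it matches an entry in custom_objects.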
Example #44
0
def convert_entity_to_ast(o, program_ctx):
    """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    ValueError: if the entity type is not supported.
  """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
    elif tf_inspect.isfunction(o):
        nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
    elif tf_inspect.ismethod(o):
        nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
    elif hasattr(o, '__class__'):
        # Note: this should only be raised when attempting to convert the object
        # directly. converted_call should still support it.
        raise NotImplementedError(
            'cannot convert entity "{}": object conversion is not yet'
            ' supported.'.format(o))
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    compiler.ast_to_source(nodes))
    if logging.has_verbosity(4):
        for n in nodes:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                        pretty_printer.fmt(n, color=False))

    return nodes, name, entity_info
Example #45
0
def convert_entity_to_ast(o, program_ctx):
  """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    ValueError: if the entity type is not supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
  elif tf_inspect.isfunction(o):
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  elif tf_inspect.ismethod(o):
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  elif hasattr(o, '__class__'):
    # Note: this should only be raised when attempting to convert the object
    # directly. converted_call should still support it.
    raise NotImplementedError(
        'cannot convert entity "{}": object conversion is not yet'
        ' supported.'.format(o))
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(nodes))
  if logging.has_verbosity(4):
    for n in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  return nodes, name, entity_info
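
A short sketch of how the dispatch above classifies entities (not from the original sources; the standard inspect module stands in for tf_inspect):

import inspect

def f():
  pass

class C(object):
  def m(self):
    pass

print(inspect.isclass(C))       # True  -> convert_class_to_ast branch
print(inspect.isfunction(f))    # True  -> convert_func_to_ast branch
print(inspect.ismethod(C().m))  # True  -> convert_func_to_ast branch
print(inspect.isfunction(C.m))  # True in Python 3: unbound methods are plain functions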
Example #46
0
    def collect_docs_for_module(self, parser_config):
        """Collect information necessary specifically for a module's doc page.

    Mainly this is information about the members of the module.

    Args:
      parser_config: An instance of ParserConfig.
    """
        relative_path = os.path.relpath(
            path='.',
            start=os.path.dirname(documentation_path(self.full_name)) or '.')

        member_names = parser_config.tree.get(self.full_name, [])
        for name in member_names:

            if name in [
                    '__builtins__', '__doc__', '__file__', '__name__',
                    '__path__', '__package__'
            ]:
                continue

            member_full_name = self.full_name + '.' + name if self.full_name else name
            member = parser_config.py_name_to_object(member_full_name)

            member_doc = _parse_md_docstring(member, relative_path,
                                             parser_config.reference_resolver)

            url = parser_config.reference_resolver.reference_to_url(
                member_full_name, relative_path)

            if tf_inspect.ismodule(member):
                self._add_module(name, member_full_name, member, member_doc,
                                 url)

            elif tf_inspect.isclass(member):
                self._add_class(name, member_full_name, member, member_doc,
                                url)

            elif tf_inspect.isfunction(member):
                self._add_function(name, member_full_name, member, member_doc,
                                   url)

            else:
                self._add_other_member(name, member_full_name, member,
                                       member_doc)
Example #47
0
def _get_raw_docstring(py_object):
  """Get the docs for a given python object.

  Args:
    py_object: A python object to retrieve the docs for (class, function/method,
      or module).

  Returns:
    The docstring, or the empty string if no docstring was found.
  """
  # For object instances, tf_inspect.getdoc does give us the docstring of their
  # type, which is not what we want. Only return the docstring if it is useful.
  if (tf_inspect.isclass(py_object) or tf_inspect.ismethod(py_object) or
      tf_inspect.isfunction(py_object) or tf_inspect.ismodule(py_object) or
      isinstance(py_object, property)):
    return tf_inspect.getdoc(py_object) or ''
  else:
    return ''
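
A small sketch of the caveat described in the comment above (not from the original sources; standard inspect is used): getdoc() on an instance falls back to its class docstring, which is why instances are excluded.

import inspect

class Widget(object):
  """Docs for the Widget class."""

w = Widget()
print(inspect.getdoc(w))   # 'Docs for the Widget class.' -- about the type, not `w`
print(inspect.isclass(w))  # False, so _get_raw_docstring would return ''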
Example #48
0
def getcallargs(c, *args, **kwargs):
  """Extension of getcallargs to non-function callables."""
  if tf_inspect.isfunction(c):
    # The traditional getcallargs
    return tf_inspect.getcallargs(c, *args, **kwargs)

  if tf_inspect.isclass(c):
    # Constructors: pass a fake None for self, then remove it.
    arg_map = tf_inspect.getcallargs(c.__init__, None, *args, **kwargs)
    assert 'self' in arg_map, 'no "self" argument, is this not a constructor?'
    del arg_map['self']
    return arg_map

  if hasattr(c, '__call__'):
    # Callable objects: map self to the object itself
    return tf_inspect.getcallargs(c.__call__, *args, **kwargs)

  raise NotImplementedError('unknown callable "%s"' % type(c))
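
A hedged usage sketch of the three branches (not from the original sources; standard inspect.getcallargs is used in place of tf_inspect.getcallargs):

import inspect

def add(a, b=1):
  return a + b

class Adder(object):
  def __init__(self, base):
    self.base = base
  def __call__(self, x):
    return self.base + x

print(inspect.getcallargs(add, 2))                         # {'a': 2, 'b': 1}

ctor_args = inspect.getcallargs(Adder.__init__, None, 10)  # fake None for self
del ctor_args['self']
print(ctor_args)                                           # {'base': 10}

adder = Adder(10)
print(inspect.getcallargs(adder.__call__, 5))              # {'self': adder, 'x': 5}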
Example #49
0
def convert(entity, program_ctx):
  """Converts an entity into an equivalent entity."""

  if tf_inspect.isfunction(entity) or tf_inspect.ismethod(entity):
    free_nonglobal_var_names = entity.__code__.co_freevars
  else:
    free_nonglobal_var_names = ()

  for i, name in enumerate(free_nonglobal_var_names):
    if (name == 'ag__' and
        entity.__closure__[i].cell_contents is not ag_internal):
      raise ValueError('entity {} uses the reserved symbol "{}"'.format(
          entity, name))
    # TODO(mdan): In extreme cases, other ag__ symbols may also be clobbered.

  converted_entity_info = _convert_with_cache(
      entity, program_ctx, free_nonglobal_var_names)

  return _instantiate(entity, converted_entity_info, free_nonglobal_var_names)
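
A short sketch of what the free-variable check above looks at (not from the original sources): co_freevars lists names captured from an enclosing scope, and __closure__ holds their current values.

def outer():
  ag__ = 'definitely not the autograph module'
  def inner(x):
    return ag__, x
  return inner

fn = outer()
print(fn.__code__.co_freevars)          # ('ag__',)
print(fn.__closure__[0].cell_contents)  # 'definitely not the autograph module'
# convert() would reject `fn` here, since it shadows the reserved 'ag__' symbol.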
Example #50
0
  def __new__(mcs, name, bases, clsdict):
    for key, value in clsdict.items():
      if key in ("__init__", "name_scope"):
        continue

      elif tf_inspect.isfunction(value):
        if getattr(value, "_no_module_name_scope", False):
          # The function has been annotated to say that no autoscoping should
          # be applied, so do not patch it.
          continue
        clsdict[key] = with_name_scope(value)

      elif isinstance(value, property):
        clsdict[key] = property(
            value.fget if not value.fget else with_name_scope(value.fget),
            value.fset if not value.fset else with_name_scope(value.fset),
            value.fdel if not value.fdel else with_name_scope(value.fdel),
            doc=value.__doc__)

    return type.__new__(mcs, name, bases, clsdict)
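
A self-contained sketch of the wrapping pattern above (not the original implementation): the stand-in with_name_scope below only tags the function, whereas the real decorator enters the module's name scope around each call.

import inspect

def with_name_scope(fn):            # stand-in for the real decorator
  def wrapped(*args, **kwargs):
    return fn(*args, **kwargs)      # the real version wraps this in a name scope
  wrapped.__wrapped__ = fn
  return wrapped

class AutoScopeMeta(type):
  def __new__(mcs, name, bases, clsdict):
    for key, value in list(clsdict.items()):
      if key in ('__init__', 'name_scope'):
        continue
      if inspect.isfunction(value):
        clsdict[key] = with_name_scope(value)
    return type.__new__(mcs, name, bases, clsdict)

class MyModule(object, metaclass=AutoScopeMeta):
  def forward(self, x):
    return x

print(hasattr(MyModule.forward, '__wrapped__'))  # True: forward was auto-wrapped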
Example #51
0
  def collect_docs_for_module(self, parser_config):
    """Collect information necessary specifically for a module's doc page.

    Mainly this is information about the members of the module.

    Args:
      parser_config: An instance of ParserConfig.
    """
    relative_path = os.path.relpath(
        path='.',
        start=os.path.dirname(documentation_path(self.full_name)) or '.')

    member_names = parser_config.tree.get(self.full_name, [])
    for name in member_names:

      if name in ['__builtins__', '__doc__', '__file__',
                  '__name__', '__path__', '__package__']:
        continue

      member_full_name = self.full_name + '.' + name if self.full_name else name
      member = parser_config.py_name_to_object(member_full_name)

      member_doc = _parse_md_docstring(member, relative_path,
                                       parser_config.reference_resolver)

      url = parser_config.reference_resolver.reference_to_url(
          member_full_name, relative_path)

      if tf_inspect.ismodule(member):
        self._add_module(name, member_full_name, member, member_doc, url)

      elif tf_inspect.isclass(member):
        self._add_class(name, member_full_name, member, member_doc, url)

      elif tf_inspect.isfunction(member):
        self._add_function(name, member_full_name, member, member_doc, url)

      else:
        self._add_other_member(name, member_full_name, member, member_doc)
Example #52
0
def has_deprecation_decorator(symbol):
  """Checks if given object has a deprecation decorator.

  We check whether the deprecation decorator is present in decorators, as well
  as whether the symbol is a class whose __init__ method has a deprecation
  decorator.

  Args:
    symbol: Python object.

  Returns:
    True if symbol has deprecation decorator.
  """
  decorators, symbol = tf_decorator.unwrap(symbol)
  if contains_deprecation_decorator(decorators):
    return True
  if tf_inspect.isfunction(symbol):
    return False
  if not tf_inspect.isclass(symbol):
    return False
  if not hasattr(symbol, '__init__'):
    return False
  init_decorators, _ = tf_decorator.unwrap(symbol.__init__)
  return contains_deprecation_decorator(init_decorators)
Example #53
0
  def compiled_function_name(self,
                             original_fqn,
                             live_entity=None,
                             owner_type=None):
    """See call_trees.FunctionNamer.compiled_function_name."""

    if not self.recursive:
      return None, False

    if (live_entity is not None and tf_inspect.isfunction(live_entity) and
        live_entity.__name__ == '<lambda>'):
      return None, False

    if owner_type is not None and owner_type not in self.partial_types:
      # Members are not renamed when part of an entire converted class.
      return None, False

    if isinstance(original_fqn, tuple):
      original_name = '__'.join(original_fqn)
    else:
      original_name = original_fqn

    if live_entity is not None and live_entity in self.renamed_calls:
      return self.renamed_calls[live_entity], True

    new_name_root = 'tf__%s' % original_name
    new_name = new_name_root
    n = 0
    while new_name in self.global_namespace:
      n += 1
      new_name = '%s_%d' % (new_name_root, n)

    if live_entity is not None:
      self.renamed_calls[live_entity] = new_name
    self.generated_names.add(new_name)

    return new_name, True
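
A tiny illustration of the collision-avoiding rename scheme used above (the namespace contents here are hypothetical):

global_namespace = {'tf__foo': object(), 'tf__foo_1': object()}

new_name_root = 'tf__%s' % 'foo'
new_name = new_name_root
n = 0
while new_name in global_namespace:
  n += 1
  new_name = '%s_%d' % (new_name_root, n)

print(new_name)  # tf__foo_2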
Example #54
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Compile a Python entity into equivalent TensorFlow.

  The function will also recursively compile all the entities that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    ValueError: if the entity type is not supported.
  """
  if program_ctx.options.verbose == converter.Verbosity.VERBOSE:
    logging.info('Converting {}'.format(o))

  if tf_inspect.isclass(o):
    node, name, ns = class_to_graph(o, program_ctx)
  elif tf_inspect.isfunction(o):
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  elif tf_inspect.ismethod(o):
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'contrib/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  # TODO(mdan): This is temporary. it should be created using a converter.
  # TODO(mdan): The attribute should be added with a helper, not directly.
  # The helper can ensure there are no collisions.
  template = '''
      entity.autograph_info__ = {}
  '''
  node.extend(templates.replace(template, entity=name))

  program_ctx.add_to_cache(o, node)

  if program_ctx.options.verbose == converter.Verbosity.VERBOSE:
    logging.info('Compiled output of {}:\n\n{}\n'.format(
        o, compiler.ast_to_source(node)))

  if program_ctx.options.recursive:
    while True:
      candidate = None
      for obj in program_ctx.name_map.keys():
        if obj not in program_ctx.dependency_cache:
          candidate = obj
          break
      if candidate is None:
        break
      if (hasattr(candidate, 'im_class') and
          getattr(candidate, 'im_class') not in program_ctx.partial_types):
        # Class members are converted with their objects, unless they're
        # only converted partially.
        continue
      entity_to_graph(candidate, program_ctx, {}, {})

  return node, name, ns
Example #55
0
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes."""
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    if class_namespace is None:
      class_namespace = namespace
    else:
      class_namespace.update(namespace)
    converted_members[m] = node[0]
  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    if isinstance(object, base):
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
Example #56
0
def converted_call(f, recursive, verbose, arg_types, *args, **kwargs):
  """Compiles a function call inline."""
  # TODO(mdan): This needs cleanup.
  # In particular, we may want to avoid renaming functions altogether.

  if conversion.is_whitelisted_for_graph(f):
    return f(*args, **kwargs)

  unknown_arg_value = object()  # Sentinel for arguments of unknown value

  if inspect_utils.isbuiltin(f):
    return builtins.dynamic_builtin(f, *args, **kwargs)

  if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
    # Regular functions
    target_entity = f
    arg_map_target = f
    effective_args = args
    f_class = inspect_utils.getmethodclass(f)

    if f_class is not None:
      partial_types = (f_class,)
    else:
      partial_types = ()

  elif tf_inspect.isclass(f):
    # Constructors
    target_entity = f
    arg_map_target = f.__init__
    effective_args = args
    partial_types = ()

  elif hasattr(f, '__call__') and hasattr(f, '__class__'):
    # Callable objects
    target_entity = f.__call__
    arg_map_target = f.__call__
    effective_args = (f,) + args
    partial_types = (f.__class__,)

  else:
    raise NotImplementedError('unknown callable type "%s"' % type(f))

  arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
  for name, arg in arg_values.items():
    if arg is unknown_arg_value:
      continue
    arg_class = arg.__class__
    # If arg_types specifies any name, use that instead.
    if name not in arg_types:
      arg_types[name] = (arg_class.__name__, arg_class)

  # When called from within a decorator, this is the only indication that
  # the function is a method - it appears that the decorator is applied
  # before the method is bound.
  if not partial_types:
    if 'self' in arg_values:
      if tf_inspect.isclass(arg_values['self'].__class__):
        partial_types = (arg_values['self'].__class__,)
    elif 'cls' in arg_values:
      if tf_inspect.isclass(arg_values['cls']):
        partial_types = (arg_values['cls'],)

  converted_f = to_graph(
      target_entity,
      recursive=recursive,
      verbose=verbose,
      arg_values=arg_values,
      arg_types=arg_types,
      partial_types=partial_types)
  return converted_f(*effective_args, **kwargs)
Example #57
0
def converted_call(f, owner, options, args, kwargs):
  """Compiles a function call inline. For internal use only."""
  logging.log(1,
              'Converted call: %s; owner: %s\n    args: %s\n    kwargs: %s\n',
              f, owner, args, kwargs)

  if owner is not None:
    if not isinstance(f, str):
      raise ValueError(
          'When owner is specified, the function name must be specified as'
          ' a string: {}'.format(f))

    # Special case when the owner is a 'super' object. In that case lookups of
    # dynamic attributes won't work. See
    # inspect_utils.SuperWrapperForDynamicAttrs.
    if isinstance(owner, super):
      owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

    f = getattr(owner, f)

  if inspect_utils.isbuiltin(f):
    return py_builtins.overload_of(f)(*args, **kwargs)

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
    logging.warn(
        'Entity {} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will be called without transformation.'
        ' You may however apply AutoGraph before the decorator.'.format(f))
    logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
    return _call_unconverted(f, args, kwargs)

  # Constructors are permanently whitelisted.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if tf_inspect.isclass(f):
    logging.log(2, 'Permanently whitelisted: %s: constructor', f)
    return _call_unconverted(f, args, kwargs)

  # Other built-in modules are permanently whitelisted.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  # Note: TF linter disallows importing inspect.
  if any(f in m.__dict__.values()
         for m in (collections, pdb, copy, tf_inspect._inspect)):  # pylint:disable=protected-access
    logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
    return _call_unconverted(f, args, kwargs)

  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
    return _call_unconverted(f, args, kwargs)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs)

  # TODO(mdan): Move this entire block inside to_graph.
  try:  # Begin of transformation error guards

    # Unwrap functools.partial objects
    # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
    while isinstance(f, functools.partial):
      args = f.args + args
      new_kwargs = {}
      if f.keywords is not None:
        new_kwargs.update(f.keywords)
      new_kwargs.update(kwargs)
      kwargs = new_kwargs
      f = f.func

    if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
      # Regular functions
      target_entity = f
      arg_map_target = f
      f_self = inspect_utils.getmethodself(f)

      # TODO(b/119246461): This may be more elegantly handled using __get__?
      if f_self is not None:
        # If this is a method call, it may or may not include self.
        #
        # Example when self is included:
        #   converted_call(to_graph(foo.bar), foo)
        #
        # Example when self is not included:
        #   super(...).foo(args)
        #
        if owner is not None and (not args or args[0] is not owner):
          effective_args = (owner,) + args
        else:
          # When the owner is not specified, use the result of
          # inspect_utils.getmethodclass.
          # TODO(b/119246461): Make sure an owner is always specified.
          if not args or args[0] is not f_self:
            effective_args = (f_self,) + args
          else:
            effective_args = (f_self,) + args[1:]
        partial_types = (f_self,)
      else:
        effective_args = args
        partial_types = ()

    elif tf_inspect.isclass(f):
      # Constructors
      # Note: Until we support class constructors, and enable whole-class
      # conversion with an experimental flag, this branch is dead code.
      # TODO(mdan): Consider removing unless there is a compelling use case.
      target_entity = f
      arg_map_target = f.__init__
      effective_args = args
      partial_types = ()

    elif hasattr(f, '__call__') and hasattr(f, '__class__'):
      # Callable objects
      target_entity = f.__call__
      arg_map_target = f.__call__
      effective_args = (f,) + args
      partial_types = (f.__class__,)

    else:
      raise NotImplementedError('unknown callable type "%s"' % type(f))

    arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
    arg_types = {}
    for name, arg in arg_values.items():
      arg_class = arg.__class__
      arg_types[name] = (arg_class.__name__, arg_class)

    # When called from within a decorator, this is the only indication that
    # the function is a method - it appears that the decorator is applied
    # before the method is bound.
    if not partial_types:
      if 'self' in arg_values:
        if tf_inspect.isclass(arg_values['self'].__class__):
          partial_types = (arg_values['self'].__class__,)
      elif 'cls' in arg_values:
        if tf_inspect.isclass(arg_values['cls']):
          partial_types = (arg_values['cls'],)

    logging.log(3, 'Partial types in conversion of %s: %s', target_entity,
                partial_types)

    converted_f = to_graph(
        target_entity,
        recursive=options.recursive,
        arg_values=arg_values,
        arg_types=arg_types,
        experimental_optional_features=options.optional_features,
        experimental_strip_decorators=options.strip_decorators,
        experimental_verbose=options.verbose,
        experimental_partial_types=partial_types)

    if logging.has_verbosity(2):
      logging.log(2, 'Defaults of %s : %s', converted_f,
                  converted_f.__defaults__)
      callargs = tf_inspect.getcallargs(converted_f, *effective_args, **kwargs)
      formatted_callargs = '\n'.join(
          '    {}: {}'.format(k, v) for k, v in callargs.items())
      logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)

  # TODO(mdan): Reduce this list.
  except (errors.AutoGraphError, AssertionError, AttributeError, IndexError,
          KeyError, NameError, NotImplementedError, SyntaxError, TypeError,
          ValueError, IOError) as e:
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    logging.warn(
        'Entity %s could not be transformed and will be staged without change.'
        ' Error details can be found in the logs when running with the env'
        ' variable AUTOGRAPH_VERBOSITY >= 1. Please report this to the'
        ' AutoGraph team. Cause: %s', target_entity, e)

    return _call_unconverted(f, args, kwargs)

  result = converted_f(*effective_args, **kwargs)

  # The converted function's closure is simply inserted into the function's
  # module __dict__. Since modules are permanently cached, that results in
  # leaking the entire closure.
  # Normally, it's not safe to delete the module because that may release said
  # closure as well. However, in the case of converted_call we are certain the
  # function will not be executed again, so the closure should no longer be
  # needed so long as the function doesn't return any executable code.
  # TODO(mdan): Attach the closure properly, using cells.
  if all(map(_is_not_callable, nest.flatten(result))):
    del sys.modules[converted_f.__module__]

  return result
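
Finally, a small sketch of the functools.partial unwrapping loop inside converted_call (not from the original sources): the partial's positional args come first, and call-site kwargs override the partial's keywords.

import functools

def f(a, b, c=0, d=0):
  return (a, b, c, d)

g = functools.partial(f, 1, c=5)
args, kwargs = (2,), {'c': 7, 'd': 9}

while isinstance(g, functools.partial):
  args = g.args + args
  new_kwargs = {}
  if g.keywords is not None:
    new_kwargs.update(g.keywords)
  new_kwargs.update(kwargs)
  kwargs = new_kwargs
  g = g.func

print(g(*args, **kwargs))  # (1, 2, 7, 9)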