# Example #1
def convert_func_to_ast(f, program_ctx, do_rename=True):
    """Specialization of `convert_entity_to_ast` for callable functions.

    Parses the source of `f` (taking its `__future__` imports into account),
    resolves origin information, runs `node_to_graph` on the parsed AST and
    names the result.

    Args:
        f: The function or lambda to convert.
        program_ctx: The conversion program context; its `autograph_module`
            is passed to `_add_self_references` along with the namespace of
            `f`.
        do_rename: Whether a non-lambda function gets a fresh name from the
            namer. Lambdas always get a fresh `tf__lambda`-based symbol.

    Returns:
        A tuple `((node,), new_name, entity_info)` holding the converted AST
        node, the name it is bound under, and the entity info built for it.

    Raises:
        ValueError: If `f` is a lambda and its source line does not contain
            exactly one lambda definition matching `f`.
    """

    future_features = inspect_utils.getfutureimports(f)
    node, source = parser.parse_entity(f, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    # Parsed AST should contain future imports and one function def node.

    # In general, the output of inspect.getsource is inexact for lambdas because
    # it uses regex matching to adjust the exact location around the line number
    # that CPython records. Then, the entire containing line is returned, which
    # we may have trouble disambiguating. For example:
    # x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        # Narrow the parsed line down to the single lambda matching `f`.
        nodes = ast_util.find_matching_definitions(node, f)
        if len(nodes) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        node, = nodes

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve_entity(node, source, f)

    # The namer is seeded with the namespace of `f`, presumably so that the
    # symbols it generates do not clash with names visible to `f`.
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    if isinstance(node, gast.Lambda):
        new_name = namer.new_symbol('tf__lambda', ())
    elif do_rename:
        new_name = namer.function_name(f.__name__)
        # NOTE(review): an `else:` line appears to be missing above the next
        # statement. As written, the namer-generated name is overwritten
        # immediately, and `new_name` is never bound when `f` is not a lambda
        # and `do_rename` is False.
        new_name = f.__name__

    # NOTE(review): the next two calls are truncated in this copy -- their
    # remaining arguments and closing parentheses are missing, so this
    # function does not parse. Callers read `entity_info.namespace`, so the
    # namespace is presumably among the missing `EntityInfo` arguments;
    # restore the full calls from upstream.
    entity_info = transformer.EntityInfo(source_code=source,
    context = converter.EntityContext(namer, entity_info, program_ctx,
    node = node_to_graph(node, context)

    if isinstance(node, gast.Lambda):
        # NOTE(review): truncated statement -- the `targets` list and the rest
        # of the call are missing. Presumably it binds the converted lambda
        # expression to `new_name` through an assignment node.
        node = gast.Assign(targets=[
    elif do_rename:
        node.name = new_name
        # NOTE(review): this assertion always holds right after the assignment
        # above; it was probably meant to sit under a missing `else:` (the
        # `do_rename=False` case, where the name is left unchanged).
        assert node.name == new_name

    return (node, ), new_name, entity_info
# Example #2
def _attach_metadata(e, f):
  """Augments an error with the metadata necessary for rewrite.

  Stores an `_ErrorMetadata` record on `e` as `e.ag_error_metadata`. The
  record is built from the traceback of the exception currently being
  handled (read through `sys.exc_info()`), so this is meant to be called from
  inside an `except` block.

  Args:
    e: The exception raised while running the converted callable `f`.
    f: The converted callable; its `ag_source_map` attribute is read.

  Returns:
    None. `e` is modified in place, unless it carries `ag_pass_through`.

  Raises:
    AttributeError: if `f` has no `ag_source_map` attribute.
  """
  # Errors explicitly marked as pass-through are left untouched.
  if hasattr(e, 'ag_pass_through'):
    return

  # If `e` already carries metadata, it is forwarded to the new record and no
  # new message is formatted.
  metadata = getattr(e, 'ag_error_metadata', None)
  source_map = f.ag_source_map

  if metadata is None:
    logging.log(1, 'Caught error in user callable %s', f, exc_info=True)
    message = '{}: {}'.format(e.__class__.__name__, e)
  else:
    message = None

  # Drop the first traceback entry: the frame of the function whose `except`
  # clause is handling the error.
  cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:]

  e.ag_error_metadata = _ErrorMetadata(
      cause_tb, metadata, message, source_map, __file__)
# Example #3
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
    """Returns a (possibly cached) factory for the converted result of entity.

    On a cache miss, `entity` is converted to AST, wrapped into a dynamic
    factory, loaded as a module, and the resulting
    `_ConvertedEntityFactoryInfo` is stored in `_CACHE` before being returned.

    Args:
        entity: The Python entity to convert.
        program_ctx: The ProgramContext; its `options` are part of the cache
            subkey.
        free_nonglobal_var_names: Names of the free non-global variables of
            `entity`; also part of the cache subkey.

    Returns:
        The cached or newly built `_ConvertedEntityFactoryInfo`.
    """
    # The cache subkey encompasses any conversion options on which the generated
    # code may depend.
    # The cached factory includes the necessary definitions to distinguish
    # between the global and non-global free variables. For this reason, the
    # cache subkey includes the names of the free non-globals.
    subkey = (program_ctx.options, frozenset(free_nonglobal_var_names))

    # Lookup, conversion and insertion all happen while holding _CACHE_LOCK.
    with _CACHE_LOCK:
        # The cache values are _ConvertedEntityFactoryInfo objects.
        if _CACHE.has(entity, subkey):
            # TODO(mdan): Check whether the module is still loaded.
            converted_entity_info = _CACHE[entity][subkey]
            logging.log(3, 'Cache hit for entity %s subkey %s: %s', entity,
                        subkey, converted_entity_info)
            return converted_entity_info

        logging.log(1, 'Entity %s is not cached for subkey %s', entity, subkey)

        nodes, converted_name, entity_info = convert_entity_to_ast(
            entity, program_ctx)

        namer = naming.Namer(entity_info.namespace)
        factory_factory_name = namer.new_symbol(
            'create_converted_entity_factory', ())
        factory_name = namer.new_symbol('create_converted_entity', ())
        # NOTE(review): the `_wrap_into_dynamic_factory(...)` call below is
        # truncated in this copy (its remaining arguments and closing
        # parenthesis are missing), so this function does not parse. Per the
        # comment at the top of this function, the factory must distinguish
        # global from non-global free variables, so
        # `free_nonglobal_var_names` is presumably one of the missing
        # arguments -- restore from upstream.
        nodes = _wrap_into_dynamic_factory(nodes, converted_name,
                                           factory_factory_name, factory_name,

        module, _, source_map = loader.load_ast(nodes, include_source_map=True)
        module_name = module.__name__

        # NOTE(review): the constructor call below is truncated (its arguments
        # and closing parenthesis are missing). `module_name` and `source_map`
        # are computed above but otherwise unused, so they are presumably
        # among the missing arguments.
        converted_entity_info = _ConvertedEntityFactoryInfo(
        _CACHE[entity][subkey] = converted_entity_info
        return converted_entity_info
# Example #4
# NOTE(review): this copy of `convert_entity_to_ast` is damaged and does not
# parse. The docstring's closing quotes are missing (they belong after the
# `Raises:` section), an `else:` is missing before the second
# `raise NotImplementedError`, and the verbosity-2 `logging.log` call is
# missing its final argument and closing parenthesis. Restore from upstream.
def convert_entity_to_ast(o, program_ctx):
    """Compile a Python entity into equivalent TensorFlow.

    Dispatches on the kind of entity: classes go to `convert_class_to_ast`,
    functions and methods go to `convert_func_to_ast`.

    Args:
        o: A Python entity.
        program_ctx: A ProgramContext object.

    Returns:
        A tuple (nodes, new_name, entity_info):
            * nodes: The AST node(s) representing an entity with interface
                equivalent to `o`, but which when executed creates a TF graph.
            * new_name: The symbol name under which the new entity can be found.
            * entity_info: Conversion metadata from the type-specific
                converter; its `namespace` maps all symbols visible to the
                converted entity, keyed by their symbol name.

    Raises:
        NotImplementedError: if entity is of a type that is not yet supported.
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
    elif tf_inspect.isfunction(o):
        nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
    elif tf_inspect.ismethod(o):
        nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
    elif hasattr(o, '__class__'):
        # Note: this should only be raised when attempting to convert the object
        # directly. converted_call should still support it.
        raise NotImplementedError(
            'cannot convert entity "{}": object conversion is not yet'
            ' supported.'.format(o))
        # NOTE(review): an `else:` appears to be missing above the next
        # statement; as written it follows a `raise` and is unreachable.
        raise NotImplementedError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    # Debug dumps: the compiled output when verbosity 2 is enabled, and each
    # AST node when verbosity 4 is enabled.
    if logging.has_verbosity(2):
        # NOTE(review): truncated call -- the value for the second `%s`
        # (presumably the generated source of `nodes`) and the closing
        # parenthesis are missing.
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
    if logging.has_verbosity(4):
        for n in nodes:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                        pretty_printer.fmt(n, color=False))

    return nodes, name, entity_info
# Example #5
# NOTE(review): this copy of `is_whitelisted` is damaged and does not parse.
# Missing pieces: the rest of the signature after `o,` (the keyword
# parameters `check_call_override` and `allow_namedtuple_subclass`, with
# their defaults), the docstring's closing quotes, two `else:` lines, and
# the openers/tails of several `logging` calls and of one recursive call.
# Inline NOTE(review) comments mark each spot; restore from upstream.
def is_whitelisted(o,
    """Checks whether an entity is whitelisted for use in graph mode.

    Whitelisted entities are called as-is rather than converted by AutoGraph.
    Examples of whitelisted entities presumably include the members of the
    tensorflow package (this sentence is truncated in this copy; the
    module-level rules themselves come from `config.CONVERSION_RULES`).

    Args:
        o: A Python entity.
        check_call_override: Reserved for internal use. When set to `False`, it
          disables the rule according to which callable objects (other than
          classes) are whitelisted if their __call__ method is whitelisted.
        allow_namedtuple_subclass: Reserved for internal use. When `True`,
          namedtuple subclasses are not whitelisted.

    Returns:
        True if `o` is whitelisted, False otherwise.

    # TODO(b/120224672): Fix this.
    if isinstance(o, functools.partial):
        # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
        # functools.partial objects do not have a __module__ attribute.
        m = functools
        # NOTE(review): an `else:` appears to be missing above the next line.
        # As written, `functools` is immediately replaced by
        # `tf_inspect.getmodule(o)` for partials, and `m` is never bound for
        # any other entity.
        m = tf_inspect.getmodule(o)

    # Examples of callables that lack a __module__ property include builtins.
    if hasattr(m, '__name__'):
        # The first matching module rule decides the outcome.
        for rule in config.CONVERSION_RULES:
            action = rule.get_action(m)
            if action == config.Action.CONVERT:
                logging.log(2, 'Not whitelisted: %s: %s', o, rule)
                return False
            elif action == config.Action.DO_NOT_CONVERT:
                logging.log(2, 'Whitelisted: %s: %s', o, rule)
                return True

    # The check for __code__ below is because isgeneratorfunction crashes
    # without one.
    if hasattr(o, '__code__') and tf_inspect.isgeneratorfunction(o):
        # NOTE(review): lines are missing here: the opening of the call that
        # takes the next two lines as its message (presumably a warning), and
        # the opening and closing of the `logging.log(2, ...)` call whose
        # format string follows.
            'Entity %s appears to be a generator function. It will not be converted'
            ' by AutoGraph.', o)
                    'Whitelisted: %s: generator functions are not converted',
        return True

    if (check_call_override and not tf_inspect.isclass(o)
            and hasattr(o, '__call__')):
        # Callable objects: whitelisted if their __call__ method is.
        # The type check avoids infinite recursion around the __call__ method
        # of function objects.
        if (type(o) != type(o.__call__)) and is_whitelisted(o.__call__):  # pylint: disable=unidiomatic-typecheck
            logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
            return True

    owner_class = None
    if tf_inspect.ismethod(o):
        # Methods of whitelisted classes are also whitelisted, even if they are
        # bound via user subclasses.
        # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
        # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
        # whitelisted.
        #   class Custom(tf.Foo):
        #     pass
        #   baz = Custom()
        # For the example above, if `Custom` did overload `bar`, then it would no
        # longer be whitelisted.

        owner_class = inspect_utils.getmethodclass(o)
        # if owner_class is function.TfMethodTarget:
        #   owner_class = o.__self__.target_class
        if owner_class is not None:
            if issubclass(owner_class, unittest.TestCase):
                # NOTE(review): the next call is truncated (its final argument
                # and closing parenthesis are missing).
                logging.log(2, 'Whitelisted: %s: method of TestCase subclass',
                return True

            owner_class = inspect_utils.getdefiningclass(o, owner_class)
            # NOTE(review): the recursive call and the `logging.log` call
            # below are truncated. The remaining argument of the recursive
            # call is presumably `check_call_override=False`; the log call is
            # missing its last argument (presumably `owner_class`) and ")".
            if is_whitelisted(owner_class,
                logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                return True

    if inspect_utils.isnamedtuple(o):
        # Due to the way they're constructed, namedtuple types cannot be converted
        # because they don't expose source code. But we assume they are safe for
        # graph mode since they are just containers.
        if allow_namedtuple_subclass:
            # Only namedtuples none of whose bases is a namedtuple qualify.
            if not any(
                    inspect_utils.isnamedtuple(base) for base in o.__bases__):
                logging.log(2, 'Whitelisted: %s: named tuple', o)
                return True
            # NOTE(review): an `else:` aligned with
            # `if allow_namedtuple_subclass:` appears to be missing above. As
            # written, every namedtuple type is whitelisted when
            # `allow_namedtuple_subclass` is True, and namedtuples fall through
            # to the default "not whitelisted" rule when it is False --
            # contradicting the docstring.
            logging.log(2, 'Whitelisted: %s: named tuple or subclass', o)
            return True

    logging.log(2, 'Not whitelisted: %s: default rule', o)
    return False
# Example #6
# NOTE(review): this copy of `converted_call` is damaged and does not parse.
# Missing pieces include the rest of the signature (the parameters `args`,
# `kwargs`, `caller_fn_scope` and `options` documented below), the
# docstring's closing quotes, several `else:` lines, the tails of many
# multi-line calls, the bodies of two `if` statements, and the `try:` that
# the final `except` belongs to. Inline NOTE(review) comments mark each
# spot; restore from upstream before use.
def converted_call(f,
  """Compiles a function call inline.

  For internal use only.

  Calls `f` either unconverted (for example when it is whitelisted, an
  AutoGraph artifact, a builtin or a constructor) or after converting it with
  AutoGraph. If transforming `f` raises, the error is logged and `f` is called
  unconverted instead.

  Note: The argument list is optimized for readability of generated code, which
  may look like this:

    ag__.converted_call(f, (arg1, arg2), None, fscope)
    ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
    ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

  Args:
    f: The function to convert.
    args: Tuple, the original positional arguments of f
    kwargs: Optional[Dict], the original keyword arguments of f
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made.
    options: Optional[converter.ConversionOptions], conversion options. If not
      specified, the value of caller_fn_scope.callopts is used. Either options
      or caller_fn_scope must be present.

  Returns:
    Any, the result of executing a possibly-converted `f` with the given
    arguments.

  Raises:
    ValueError: if neither `options` nor `caller_fn_scope` is given.
  # NOTE(review): truncated call -- the value for the third `%s`
  # (presumably `kwargs`) and the closing parenthesis are missing.
  logging.log(1, 'Converted call: %s\n    args: %s\n    kwargs: %s\n', f, args,

  if options is None:
    if caller_fn_scope is None:
      raise ValueError('either caller_fn_scope or options must have a value')
    options = caller_fn_scope.callopts

  # Early exits: each check below calls `f` unconverted when it matches.
  if conversion.is_in_whitelist_cache(f, options):
    logging.log(2, 'Whitelisted %s: from cache', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    logging.log(2, 'Whitelisted: %s: AutoGraph is disabled in context', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if is_autograph_artifact(f):
    logging.log(2, 'Permanently whitelisted: %s: AutoGraph artifact', f)
    return _call_unconverted(f, args, kwargs, options)

  # If this is a partial, unwrap it and redo all the checks.
  if isinstance(f, functools.partial):
    new_kwargs = {}
    if f.keywords is not None:
      # Use copy to avoid mutating the underlying keywords.
      new_kwargs = f.keywords.copy()
    if kwargs is not None:
      # NOTE(review): the body of this `if` is missing; presumably it merges
      # `kwargs` into `new_kwargs`.
    new_args = f.args + args
    # NOTE(review): the next two calls are truncated (their remaining
    # arguments and closing parentheses are missing).
    logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args,
    return converted_call(

  if inspect_utils.isbuiltin(f):
    if f is eval:
      return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
    if f is super:
      return py_builtins.super_in_original_context(f, args, caller_fn_scope)
    if kwargs:
      return py_builtins.overload_of(f)(*args, **kwargs)
      # NOTE(review): an `else:` appears to be missing above; as written the
      # next line follows a `return` and is unreachable.
      return py_builtins.overload_of(f)(*args)

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
    # NOTE(review): the opening of the call that takes the next three lines
    # as its message (presumably a warning) is missing.
        '{} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will run as-is.'
        ' You may still apply AutoGraph before the wrapt decorator.'.format(f))
    logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
    return _call_unconverted(f, args, kwargs, options)

  if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
    return _call_unconverted(f, args, kwargs, options)

  # Constructors are permanently whitelisted.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if inspect_utils.isconstructor(f):
    logging.log(2, 'Permanently whitelisted: %s: constructor', f)
    return _call_unconverted(f, args, kwargs, options)

  # Other built-in modules are permanently whitelisted.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(
      f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)):
    logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
    return _call_unconverted(f, args, kwargs, options)

  # Custom ops and kernels are also permanently whitelisted.
  # See tensorflow.framework.load_library.
  if (hasattr(f, '__module__') and
      hasattr(f.__module__, '_IS_TENSORFLOW_PLUGIN')):
    logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f)
    return _call_unconverted(f, args, kwargs, options)

  if not options.user_requested and conversion.is_whitelisted(f):
    return _call_unconverted(f, args, kwargs, options)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs, options)

  # TODO(mdan): Move this entire block inside to_graph.
  try:  # Begin of transformation error guards

    if inspect.ismethod(f) or inspect.isfunction(f):
      target_entity = f
      effective_args = args

      f_self = getattr(f, '__self__', None)
      if f_self is not None:
        # if isinstance(f_self, function.TfMethodTarget):
        #   f_self = f_self.target
        # Bound methods: the bound instance is passed explicitly as the first
        # argument of the converted function.
        effective_args = (f_self,) + effective_args

    elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
      # Callable objects. Dunder methods have special lookup rules, see:
      # https://docs.python.org/3/reference/datamodel.html#specialnames
      target_entity = f.__class__.__call__
      effective_args = (f,) + args

      # NOTE(review): an `else:` appears to be missing above the next two
      # lines; as written, every callable object ends up at the `raise`.
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

    if not inspect.isclass(target_entity):
      if not hasattr(target_entity, '__code__'):
        # NOTE(review): truncated call (final argument and ")" missing).
        logging.log(2, 'Permanently whitelisted: %s: native binding',
        return _call_unconverted(f, args, kwargs, options)
      elif (hasattr(target_entity.__code__, 'co_filename') and
            target_entity.__code__.co_filename == '<string>'):
        # TODO(mdan): __globals__['txt'] might work in Py3.
        # NOTE(review): truncated call (final argument and ")" missing).
        logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)',
        return _call_unconverted(f, args, kwargs, options)

    program_ctx = converter.ProgramContext(
        options=options, autograph_module=tf_inspect.getmodule(converted_call))
    converted_f = conversion.convert(target_entity, program_ctx)

    if logging.has_verbosity(2):
      # NOTE(review): truncated call -- the value for the second `%s`
      # (presumably `converted_f.__defaults__`) and ")" are missing.
      logging.log(2, 'Defaults of %s : %s', converted_f,
      if not six.PY2:
        logging.log(2, 'KW defaults of %s : %s',
                    converted_f, converted_f.__kwdefaults__)

      # NOTE(review): the first `getcallargs` call is truncated (presumably
      # `**kwargs)` is missing) and an `else:` is missing before the second.
      if kwargs is not None:
        callargs = tf_inspect.getcallargs(converted_f, *effective_args,
        callargs = tf_inspect.getcallargs(converted_f, *effective_args)

      formatted_callargs = '\n'.join(
          '    {}: {}'.format(k, v) for k, v in callargs.items())
      logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)

  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      # NOTE(review): the body of this `if` is missing; presumably it
      # re-raises the error in strict mode.

    # NOTE(review): this template has two `%s` placeholders, but both
    # `logging.warn` calls below pass three arguments (entity, extra message,
    # error); a line holding the middle `%s` appears to be missing.
    warning_template = (
        'AutoGraph could not transform %s and will run it as-is.\n'
        'Cause: %s\n'
        'To silence this warning, decorate the function with'
        ' @tf.autograph.experimental.do_not_convert')
    if isinstance(e, errors.UnsupportedLanguageElementError):
      # Repeating the check made upon function entry because the state might
      # have updated in the meantime.
      if not conversion.is_in_whitelist_cache(f, options):
        logging.warn(warning_template, target_entity, '', e)
      # NOTE(review): an `else:` aligned with `if isinstance(...)` appears to
      # be missing above. As written, the bug-report warning below is logged
      # only for UnsupportedLanguageElementError (after the warning above),
      # and no warning is logged for other transformation errors.
      file_bug_message = (
          'Please report this to the TensorFlow team. When filing the bug, set'
          ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
          ' attach the full output.\n')
      logging.warn(warning_template, target_entity, file_bug_message, e)

    return _call_unconverted(f, args, kwargs, options)

  # NOTE(review): the `with` statement below is commented out and the `try:`
  # that the `except` clause belongs to is missing, so this does not parse.
  # An `else:` is also missing before the second `converted_f` call, and the
  # handler does not re-raise after `_attach_metadata`: if the error were
  # swallowed, `result` would be unbound at the `return`. Presumably a bare
  # `raise` belongs after `_attach_metadata(e, converted_f)`.
  #with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
    if kwargs is not None:
      result = converted_f(*effective_args, **kwargs)
      result = converted_f(*effective_args)
  except Exception as e:
    _attach_metadata(e, converted_f)

  return result