示例#1
0
    def _transformed_factory(self, fn, cache_subkey, user_context,
                             extra_locals):
        """Returns the transformed function factory for a given input."""
        if self._cache.has(fn, cache_subkey):
            return self._cached_factory(fn, cache_subkey)

        with self._cache_lock:
            # Check again under lock.
            if self._cache.has(fn, cache_subkey):
                return self._cached_factory(fn, cache_subkey)

            logging.log(1, '%s is not cached for subkey %s', fn, cache_subkey)
            nodes, ctx = self._transform_function(fn, user_context)

            if logging.has_verbosity(2):
                logging.log(2, 'Transformed %s:\n\n%s\n', fn,
                            parser.unparse(nodes))

            factory = _TransformedFnFactory(ctx.info.name,
                                            fn.__code__.co_freevars,
                                            extra_locals)
            factory.create(nodes,
                           ctx.namer,
                           future_features=ctx.info.future_features)
            self._cache[fn][cache_subkey] = factory
        return factory
示例#2
0
def convert_func_to_ast(f, program_ctx, do_rename=True):
    """Converts a Python function or lambda into a transformed AST.

    This is the callable-specific specialization of `convert_entity_to_ast`.

    Args:
      f: The function or lambda to convert.
      program_ctx: The program context; its `autograph_module` is registered
        in the entity's namespace via `_add_self_references`.
      do_rename: If True, a function def receives a fresh name from
        `namer.function_name`; otherwise its name must already equal
        `f.__name__`. Lambdas always get a fresh `tf__lambda` symbol.

    Returns:
      A tuple `((node,), new_name, entity_info)`. `node` is the converted
      function def, or a `gast.Assign` binding a converted lambda to
      `new_name`.

    Raises:
      ValueError: if `f` is a lambda and its source does not contain exactly
        one matching lambda definition.
    """
    future_features = inspect_utils.getfutureimports(f)
    node, source = parser.parse_entity(f, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    # The parsed AST holds the future imports and one function def node.

    # inspect.getsource is imprecise for lambdas: it locates the code by
    # regex-matching around the line number CPython recorded and returns the
    # whole containing line, which may hold several lambdas, e.g.
    #   x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        candidates = ast_util.find_matching_definitions(node, f)
        if len(candidates) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        node, = candidates

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve_entity(node, source, f)

    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    # The new name is chosen before conversion so that it can be passed to
    # the EntityContext.
    if isinstance(node, gast.Lambda):
        new_name = namer.new_symbol('tf__lambda', ())
    else:
        new_name = namer.function_name(f.__name__) if do_rename else f.__name__

    entity_info = transformer.EntityInfo(
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    converted = node_to_graph(
        node, converter.EntityContext(namer, entity_info, program_ctx, new_name))

    if isinstance(converted, gast.Lambda):
        # A lambda carries no name, so bind it to new_name with an assignment.
        target = gast.Name(new_name,
                           ctx=gast.Store(),
                           annotation=None,
                           type_comment=None)
        converted = gast.Assign(targets=[target], value=converted)
    elif do_rename:
        converted.name = new_name
    else:
        assert converted.name == new_name

    return (converted, ), new_name, entity_info
示例#3
0
def report_internal_error(entity, exception):
    """Logs the active exception and raises an AutoGraphError describing it.

    Args:
      entity: The entity whose transformation failed.
      exception: The underlying exception; its string form is embedded in the
        raised error's message.

    Raises:
      AutoGraphError: always.
    """
    ag_logging.log(1, 'Error transforming %s', entity, exc_info=True)
    # TODO(znado): Add external bug reporting instructions.
    # The shell example must use the same verbosity level (10) that the
    # message asks for.
    raise AutoGraphError(
        'Unexpected error transforming %s. If you believe this is due to a bug,'
        ' please set the verbosity to 10 (on Linux, `export '
        'AUTOGRAPH_VERBOSITY=10`) and attach the full output when filing the bug '
        'report. Caused by %s' % (entity, exception))
示例#4
0
def report_internal_error(entity, exception):
  """Logs the active exception, then raises an AutoGraphError about it.

  Args:
    entity: The entity whose transformation failed.
    exception: The underlying exception, interpolated into the message.

  Raises:
    AutoGraphError: unconditionally.
  """
  ag_logging.log(1, 'Error transforming %s', entity, exc_info=True)
  # TODO(znado): Add external bug reporting instructions.
  message = (
      'Unexpected error transforming %s. If you believe this is due to a bug,'
      ' please set the verbosity to 10 (on Linux, `export '
      'AUTOGRAPH_VERBOSITY=10`) and attach the full output when filing the bug '
      'report. Caused by: %s') % (entity, exception)
  raise AutoGraphError(message)
示例#5
0
    def transform_function(self, fn, user_context):
        """Transforms a function. See GenericTranspiler.transform_function.

        This overload wraps the parent's `transform_function`, adding caching
        and facilities to instantiate the output as a Python object. It also
        adds facilities to make new symbols available to the generated Python
        code, visible as local variables - see `get_extra_locals`.

        Factories are cached under `(fn, self.get_caching_key(user_context))`.
        The cache is probed without locking first. On a miss it is probed
        again while holding `self._cache_lock` before the parent
        transformation runs.

        Args:
          fn: A function or lambda.
          user_context: An opaque object (may be None) that is forwarded to
            transform_ast, through the ctx.user_context argument.

        Returns:
          A tuple:
            * A function or lambda with the same signature and closure as `fn`
            * The temporary module into which the transformed function was
              loaded
            * The source map as a
                Dict[origin_info.LineLocation, origin_info.OriginInfo]
        """
        cache_subkey = self.get_caching_key(user_context)

        if not self._cache.has(fn, cache_subkey):
            with self._cache_lock:
                # Another caller may have filled the entry while we waited.
                if self._cache.has(fn, cache_subkey):
                    factory = self._cached_factory(fn, cache_subkey)
                else:
                    logging.log(1, '%s is not cached for subkey %s', fn,
                                cache_subkey)
                    # TODO(mdan): Confusing overloading pattern. Fix.
                    nodes, ctx = super(PyToPy, self).transform_function(
                        fn, user_context)

                    if logging.has_verbosity(2):
                        logging.log(2, 'Transformed %s:\n\n%s\n', fn,
                                    parser.unparse(nodes))

                    factory = _PythonFnFactory(ctx.info.name,
                                               fn.__code__.co_freevars,
                                               self.get_extra_locals())
                    factory.create(nodes, ctx.namer,
                                   future_features=ctx.info.future_features)
                    self._cache[fn][cache_subkey] = factory
        else:
            # Fast path: lock-free hit.
            factory = self._cached_factory(fn, cache_subkey)

        # Rebind the new code to fn's own globals, closure and defaults.
        transformed_fn = factory.instantiate(
            globals_=fn.__globals__,
            closure=fn.__closure__ or (),
            defaults=fn.__defaults__,
            kwdefaults=getattr(fn, '__kwdefaults__', None))
        return transformed_fn, factory.module, factory.source_map
示例#6
0
def convert_func_to_ast(f, program_ctx, do_rename=True):
  """Converts a Python function or lambda into a transformed AST.

  Callable-specific specialization of `convert_entity_to_ast`.

  Args:
    f: The function or lambda to convert.
    program_ctx: The program context; its `autograph_module` is added to the
      entity's namespace via `_add_self_references`.
    do_rename: If True, a converted function def is renamed using
      `namer.function_name`; otherwise its name must already equal
      `f.__name__`. Converted lambdas are always bound to a fresh
      `tf__lambda` symbol.

  Returns:
    A tuple `((node,), new_name, entity_info)`.

  Raises:
    ValueError: if `f` is a lambda and its source line does not yield exactly
      one matching definition.
    errors.InternalError: if `node_to_graph` raises ValueError,
      AttributeError, KeyError or NotImplementedError.
  """
  future_features = inspect_utils.getfutureimports(f)
  node, source = parser.parse_entity(f, future_features=future_features)
  logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
  # The parsed AST holds the future imports and one function def node.

  # inspect.getsource is imprecise for lambdas: it returns the whole line
  # around the recorded line number, which may contain several of them, e.g.
  #   x, y = lambda: 1, lambda: 2
  if f.__name__ == '<lambda>':
    matches = ast_util.find_matching_definitions(node, f)
    if len(matches) != 1:
      raise ValueError(
          'Unable to identify source code of lambda function {}. It was'
          ' defined on this line: {}, which must contain a single lambda with'
          ' matching signature. To avoid ambiguity, define each lambda'
          ' in a separate expression.'.format(f, source))
    node, = matches

  # TODO(znado): Place inside standard_analysis.
  origin_info.resolve(node, source, f)
  namespace = inspect_utils.getnamespace(f)
  _add_self_references(namespace, program_ctx.autograph_module)
  namer = naming.Namer(namespace)

  entity_info = transformer.EntityInfo(
      source_code=source,
      source_file='<fragment>',
      future_features=future_features,
      namespace=namespace)
  context = converter.EntityContext(namer, entity_info, program_ctx)
  # TODO(mdan): Catch and rethrow syntax errors.
  try:
    node = node_to_graph(node, context)
  except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
    logging.error(1, 'Error converting %s', f, exc_info=True)
    raise errors.InternalError('conversion', e)

  if not isinstance(node, gast.Lambda):
    if do_rename:
      new_name = namer.function_name(f.__name__)
      node.name = new_name
    else:
      new_name = f.__name__
      assert node.name == new_name
  else:
    # Lambdas have no name; bind the converted lambda to a fresh symbol.
    new_name = namer.new_symbol('tf__lambda', ())
    node = gast.Assign(
        targets=[gast.Name(new_name, gast.Store(), None)], value=node)

  return (node,), new_name, entity_info
示例#7
0
def function_to_graph(f, program_ctx, arg_values, arg_types, do_rename=True):
    """Converts a Python function or lambda into a graph-building AST.

    Specialization of `entity_to_graph` for callable functions.

    Args:
      f: The function or lambda to convert.
      program_ctx: The program context; its `autograph_module` is added to the
        entity's namespace via `_add_self_references`.
      arg_values: Value hints for symbols such as function parameters,
        forwarded to `transformer.EntityInfo`.
      arg_types: Type hints for symbols such as function parameters,
        forwarded to `transformer.EntityInfo`.
      do_rename: If True, a converted function def is renamed with
        `namer.function_name`; otherwise its name must already equal
        `f.__name__`. Lambdas are always bound to a fresh `tf__lambda` symbol.

    Returns:
      A tuple `([node], new_name, namespace)`.

    Raises:
      ValueError: if `f` is a lambda and its source line does not yield
        exactly one matching definition.
      errors.InternalError: if `node_to_graph` raises ValueError,
        AttributeError, KeyError or NotImplementedError.
    """
    parsed, source = parser.parse_entity(f)
    logging.log(3, 'Source code of %s:\n\n%s\n', f, source)
    node = parsed.body[0]

    # inspect.getsource is imprecise for lambdas: it returns the whole line
    # around the recorded line number, which may contain several of them, e.g.
    #   x, y = lambda: 1, lambda: 2
    if f.__name__ == '<lambda>':
        matches = ast_util.find_matching_definitions(node, f)
        if len(matches) != 1:
            raise ValueError(
                'Unable to identify source code of lambda function {}. It was'
                ' defined on this line: {}, which must contain a single lambda with'
                ' matching signature. To avoid ambiguity, define each lambda'
                ' in a separate expression.'.format(f, source))
        node, = matches

    # TODO(znado): Place inside standard_analysis.
    origin_info.resolve(node, source, f)
    namespace = inspect_utils.getnamespace(f)
    _add_self_references(namespace, program_ctx.autograph_module)
    namer = naming.Namer(namespace)

    entity_info = transformer.EntityInfo(source_code=source,
                                         source_file='<fragment>',
                                         namespace=namespace,
                                         arg_values=arg_values,
                                         arg_types=arg_types)
    context = converter.EntityContext(namer, entity_info, program_ctx)
    # TODO(mdan): Catch and rethrow syntax errors.
    try:
        node = node_to_graph(node, context)
    except (ValueError, AttributeError, KeyError, NotImplementedError) as e:
        logging.error(1, 'Error converting %s', f, exc_info=True)
        raise errors.InternalError('conversion', e)

    if not isinstance(node, gast.Lambda):
        if do_rename:
            # TODO(mdan): This somewhat duplicates the renaming logic in call_trees.py
            new_name = namer.function_name(f.__name__)
            node.name = new_name
        else:
            new_name = f.__name__
            assert node.name == new_name
    else:
        # Lambdas have no name; bind the converted lambda to a fresh symbol.
        new_name = namer.new_symbol('tf__lambda', ())
        node = gast.Assign(targets=[gast.Name(new_name, gast.Store(), None)],
                           value=node)

    return [node], new_name, namespace
示例#8
0
    def _transform_function(self, fn, user_context):
        """Parses `fn`, transforms its AST and gives the result a fresh name.

        Args:
          fn: A function or lambda.
          user_context: An opaque object (may be None), stored in the
            `transformer.Context` that is passed to `transform_ast`.

        Returns:
          A tuple `(node, context)`. `node` is the transformed, renamed
          function def, or a `gast.Assign` binding a transformed lambda to the
          new name. `context` is the `transformer.Context` that was used.

        Raises:
          ValueError: if `fn` is a lambda and its source does not contain
            exactly one matching lambda definition.
        """
        future_features = inspect_utils.getfutureimports(fn)
        node, source = parser.parse_entity(fn, future_features=future_features)
        logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

        # inspect.getsource is imprecise for lambdas: it returns the whole
        # line around the line number CPython recorded, which may contain
        # several lambdas, e.g.
        #   x, y = lambda: 1, lambda: 2
        is_lambda = fn.__name__ == '<lambda>'
        if is_lambda:
            matches = ast_util.find_matching_definitions(node, fn)
            if len(matches) != 1:
                raise ValueError(
                    'Unable to identify source code of lambda function {}.'
                    ' It was defined in this code:\n'
                    '{}\n'
                    'This code must contain a single distinguishable lambda.'
                    ' To avoid this problem, define each lambda in a separate'
                    ' expression.'.format(fn, source))
            node, = matches

        origin_info.resolve_entity(node, source, fn)

        namespace = inspect_utils.getnamespace(fn)
        namer = naming.Namer(namespace)
        new_name = namer.new_symbol(self.get_transformed_name(node), ())
        context = transformer.Context(
            transformer.EntityInfo(name=new_name,
                                   source_code=source,
                                   source_file='<fragment>',
                                   future_features=future_features,
                                   namespace=namespace),
            namer,
            user_context)

        transformed = self.transform_ast(self._erase_arg_defaults(node),
                                         context)

        if not is_lambda:
            transformed.name = new_name
            return transformed, context

        # Lambdas have no name; bind the transformed lambda to new_name.
        target = gast.Name(new_name,
                           ctx=gast.Store(),
                           annotation=None,
                           type_comment=None)
        return gast.Assign(targets=[target], value=transformed), context
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
  """Returns a (possibly cached) factory for the converted result of entity.

  The cache lookup and, on a miss, the conversion and cache insertion all
  happen while holding `_CACHE_LOCK`.

  Args:
    entity: The entity to convert.
    program_ctx: The program context; `program_ctx.options` is part of the
      cache subkey.
    free_nonglobal_var_names: Names of the entity's free non-global
      variables; part of the cache subkey, and forwarded to
      `_wrap_into_dynamic_factory`.

  Returns:
    The `_ConvertedEntityFactoryInfo` for the converted entity.
  """
  # Key by the code object when there is one, otherwise by the entity itself.
  # Keying by the code object lets functions that are created dynamically
  # (e.g. in a loop) share one cache entry.
  key = entity.__code__ if hasattr(entity, '__code__') else entity

  # The subkey covers every conversion option the generated code may depend
  # on. The cached factory also encodes which free variables are global and
  # which are not, so the names of the free non-globals belong in it too.
  subkey = (program_ctx.options, frozenset(free_nonglobal_var_names))

  with _CACHE_LOCK:
    # Cache values are _ConvertedEntityFactoryInfo objects.
    if _CACHE.has(key, subkey):
      # TODO(mdan): Check whether the module is still loaded.
      cached_info = _CACHE[key][subkey]
      logging.log(3, 'Cache hit for entity %s key %s subkey %s: %s', entity,
                  key, subkey, cached_info)
      return cached_info

    logging.log(1, 'Entity %s is not cached for key %s subkey %s', entity, key,
                subkey)

    nodes, converted_name, entity_info = convert_entity_to_ast(
        entity, program_ctx)

    namer = naming.Namer(entity_info.namespace)
    factory_factory_name = namer.new_symbol('create_converted_entity_factory',
                                            ())
    factory_name = namer.new_symbol('create_converted_entity', ())
    wrapped_nodes = _wrap_into_dynamic_factory(
        nodes, converted_name, factory_factory_name, factory_name,
        free_nonglobal_var_names, entity_info.future_features)

    module, _, source_map = compiler.ast_to_object(
        wrapped_nodes, include_source_map=True)

    new_info = _ConvertedEntityFactoryInfo(
        module_name=module.__name__,
        converted_name=converted_name,
        factory_factory_name=factory_factory_name,
        source_map=source_map)
    _CACHE[key][subkey] = new_info
    return new_info
示例#10
0
def _convert_with_cache(entity, program_ctx, free_nonglobal_var_names):
  """Returns a (possibly cached) factory for the converted result of entity.

  Args:
    entity: The entity to convert.
    program_ctx: The program context; `program_ctx.options` is part of the
      cache subkey.
    free_nonglobal_var_names: Names of the entity's free non-global
      variables; part of the cache subkey, and forwarded to
      `_wrap_into_dynamic_factory`.

  Returns:
    The `_ConvertedEntityFactoryInfo` for the converted entity.
  """
  # Key by the code object when there is one, otherwise by the entity itself.
  # Keying by the code object lets functions that are created dynamically
  # (e.g. in a loop) share one cache entry.
  key = entity.__code__ if hasattr(entity, '__code__') else entity

  # The subkey covers every conversion option the generated code may depend
  # on. The cached factory also encodes which free variables are global and
  # which are not, so the names of the free non-globals belong in it too.
  subkey = (program_ctx.options, frozenset(free_nonglobal_var_names))

  # Cache values are _ConvertedEntityFactoryInfo objects.
  if _CACHE.has(key, subkey):
    # TODO(mdan): Check whether the module is still loaded.
    cached_info = _CACHE[key][subkey]
    logging.log(3, 'Cache hit for entity %s key %s subkey %s: %s', entity, key,
                subkey, cached_info)
    return cached_info

  logging.log(1, 'Entity %s is not cached for key %s subkey %s', entity, key,
              subkey)

  nodes, converted_name, entity_info = convert_entity_to_ast(
      entity, program_ctx)

  namer = naming.Namer(entity_info.namespace)
  factory_factory_name = namer.new_symbol('create_converted_entity_factory', ())
  factory_name = namer.new_symbol('create_converted_entity', ())
  wrapped_nodes = _wrap_into_dynamic_factory(
      nodes, converted_name, factory_factory_name, factory_name,
      free_nonglobal_var_names, entity_info.future_features)

  module, _, source_map = compiler.ast_to_object(
      wrapped_nodes, include_source_map=True)

  new_info = _ConvertedEntityFactoryInfo(
      module_name=module.__name__,
      converted_name=converted_name,
      factory_factory_name=factory_factory_name,
      source_map=source_map)
  _CACHE[key][subkey] = new_info
  return new_info
示例#11
0
def _attach_metadata(e, f, converted):
  """Augments an error with the metadata necessary for rewrite."""
  if hasattr(e, 'ag_pass_through'):
    return

  metadata = getattr(e, 'ag_error_metadata', None)
  source_map = f.ag_source_map if converted else {}

  if metadata is None:
    logging.log(
        1, 'Caught error in %s (converted=%s)', f, converted, exc_info=True)
    message = '{}: {}'.format(e.__class__.__name__, e)
  else:
    message = None

  cause_tb = traceback.extract_tb(sys.exc_info()[2])[1:]
  e.ag_error_metadata = _ErrorMetadata(cause_tb, metadata, message, source_map)
示例#12
0
    def transform_function(self, fn, user_context):
        """Parses `fn`, transforms its AST and renames the result.

        Subclasses may override this method. The return value is opaque: the
        result is passed as-is to the output of `transform`.

        Args:
          fn: A function or lambda.
          user_context: An opaque object (may be None) that is forwarded to
            transform_ast, through the ctx.user_context argument.

        Returns:
          A tuple `(node, context)`. `node` is the output of transform_ast
          under a fresh name: a lambda is wrapped in a `gast.Assign` that binds
          it to that name, and any other node has its `name` replaced.
          `context` is the `transformer.Context` used for the transformation.
        """
        future_features = inspect_utils.getfutureimports(fn)
        node, source = parser.parse_entity(fn, future_features=future_features)
        logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

        origin_info.resolve_entity(node, source, fn)

        namespace = inspect_utils.getnamespace(fn)
        namer = naming.Namer(namespace)
        new_name = namer.new_symbol(self.get_transformed_name(node), ())
        info = transformer.EntityInfo(name=new_name,
                                      source_code=source,
                                      source_file='<fragment>',
                                      future_features=future_features,
                                      namespace=namespace)
        context = transformer.Context(info, namer, user_context)

        result = self.transform_ast(self._erase_arg_defaults(node), context)

        if not isinstance(result, gast.Lambda):
            result.name = new_name
            return result, context

        # Lambdas have no name of their own; bind the result to new_name.
        store_target = gast.Name(new_name,
                                 ctx=gast.Store(),
                                 annotation=None,
                                 type_comment=None)
        return gast.Assign(targets=[store_target], value=result), context
示例#13
0
 def guarded_extra_test():
   """Calls `extra_test()` and coerces its result to a Python bool.

   Raises:
     NotImplementedError: chained from
       `errors_impl.OperatorNotAllowedInGraphError` when the result cannot be
       converted with `bool()`.
   """
   outcome = extra_test()
   # try/except is used instead of tensor_util.is_tf_type to avoid a
   # performance penalty.
   try:
     return bool(outcome)
   except errors_impl.OperatorNotAllowedInGraphError as e:
     ag_logging.log(
         1, 'Caught error while evaluating loop stop condition', exc_info=True)
     # TODO(mdan): We can pass the location of extra_test and show it here.
     raise NotImplementedError(
         'break and return statements which depend on a TF condition are not'
         ' supported in Python for loops. Did you intend to make it a TF'
         ' loop?\nSee '
         'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
         'python/autograph/g3doc/reference/limitations.md'
         '#consistency-of-control-flow-types for more info.') from e
示例#14
0
def is_unsupported(o):
    """Returns True if AutoGraph must leave `o` unconverted ("permanently allowed").

    `o` is reported as unsupported when it is wrapped by `wrapt`, is a
    `functools` LRU-cache wrapper, is a constructor, belongs to one of a few
    standard-library modules, or carries the TensorFlow plugin marker. Each
    match is logged at verbosity 2; the wrapt case also emits a warning.
    """
    # TODO(b/122265385): Remove this bypass.
    wrapt_decorated = (
        _is_known_loaded_type(o, 'wrapt', 'FunctionWrapper')
        or _is_known_loaded_type(o, 'wrapt', 'BoundFunctionWrapper'))
    if wrapt_decorated:
        logging.warning(
            ('{} appears to be decorated by wrapt, which is not yet supported'
             ' by AutoGraph. The function will run as-is.'
             ' You may still apply AutoGraph before the wrapt decorator.'
             ).format(o))
        logging.log(2, 'Permanently allowed: %s: wrapt decorated', o)
        return True

    # Remaining categories, tried in order. Each predicate is a thunk so it is
    # only evaluated if all earlier ones failed.
    permanently_allowed = (
        (lambda: _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'),
         'Permanently allowed: %s: lru_cache'),
        # Constructors are permanently allowed.
        # TODO(mdan): Toggle as experimental feature instead.
        # TODO(b/124016764): Remove this limitation.
        (lambda: inspect_utils.isconstructor(o),
         'Permanently allowed: %s: constructor'),
        # Other built-in modules are permanently allowed.
        # TODO(mdan): Figure out how to do this consistently for all stdlib
        # modules.
        (lambda: any(_is_of_known_loaded_module(o, m)
                     for m in ('collections', 'pdb', 'copy', 'inspect', 're')),
         'Permanently allowed: %s: part of builtin module'),
        # Custom ops and kernels are also permanently allowed.
        # See tensorflow.framework.load_library.
        (lambda: (hasattr(o, '__module__')
                  and hasattr(o.__module__, '_IS_TENSORFLOW_PLUGIN')),
         'Permanently allowed: %s: TensorFlow plugin'),
    )
    for matches, log_message in permanently_allowed:
        if matches():
            logging.log(2, log_message, o)
            return True

    return False
示例#15
0
 def guarded_test():
   """Calls `test()` and coerces its result to a Python bool.

   Raises:
     NotImplementedError: chained from
       `errors_impl.OperatorNotAllowedInGraphError` when the result cannot be
       converted with `bool()`.
   """
   outcome = test()
   # try/except is used instead of tensor_util.is_tf_type to avoid a
   # performance penalty.
   try:
     return bool(outcome)
   except errors_impl.OperatorNotAllowedInGraphError as e:
     ag_logging.log(
         1, 'Caught error while evaluating while loop condition', exc_info=True)
     # TODO(mdan): distinguish between these two cases.
     raise NotImplementedError(
         'The condition of while loop started as non-Tensor, then changed to'
         ' Tensor. This may happen either because variables changed type, or'
         ' when a break or return statement inside the loop depends on a'
         ' Tensor condition. In both cases, changing to a TF loop should'
         ' remove the error.\nSee '
         'https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
         'python/autograph/g3doc/reference/limitations.md'
         '#consistency-of-control-flow-types for more info.') from e
示例#16
0
def convert_entity_to_ast(o, program_ctx):
    """Converts a Python callable into an equivalent TensorFlow AST.

    Every entity is handed to `convert_func_to_ast`. The compiled output is
    also logged as source at verbosity 2 and as a pretty-printed AST at
    verbosity 4.

    Args:
      o: A Python function or lambda.
      program_ctx: A ProgramContext object.

    Returns:
      The tuple `(nodes, new_name, entity_info)` produced by
      `convert_func_to_ast`:
        * nodes: AST node(s) for an entity whose interface is equivalent to
          `o`, but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * entity_info: The entity information returned by
          `convert_func_to_ast`, including the namespace visible to the
          converted entity.

    Raises:
      Whatever `convert_func_to_ast` raises.
    """
    logging.log(1, 'Converting %s', o)

    converted_nodes, new_name, entity_info = convert_func_to_ast(
        o, program_ctx)

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    parser.unparse(converted_nodes))
    if logging.has_verbosity(4):
        for converted_node in converted_nodes:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                        pretty_printer.fmt(converted_node, color=False))

    return converted_nodes, new_name, entity_info
示例#17
0
  def transform_function(self, fn, user_context):
    """Parses `fn` and runs `transform_ast` on it.

    Subclasses may override this method. The return value is opaque: it is
    passed as-is to the output of `transform`. The method receives the
    original AST.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None) that is forwarded to
        transform_ast, through the ctx.user_context argument.

    Returns:
      Tuple[Any, Any]. By default it returns the output of transform_ast,
      together with the `transformer.Context` describing the transformation.
    """
    features = inspect_utils.getfutureimports(fn)
    node, source = parser.parse_entity(fn, future_features=features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

    origin_info.resolve_entity(node, source, fn)

    namespace = inspect_utils.getnamespace(fn)
    namer = naming.Namer(namespace)
    info = transformer.EntityInfo(
        name=namer.new_symbol(self.get_transformed_name(node), ()),
        source_code=source,
        source_file='<fragment>',
        future_features=features,
        namespace=namespace)
    context = transformer.Context(info, namer, user_context)

    stripped = self._erase_arg_defaults(node)
    return self.transform_ast(stripped, context), context
示例#18
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

    Classes are handed to `class_to_graph`; functions and methods to
    `function_to_graph`. The compiled output is logged as source at
    verbosity 2 and as a pretty-printed AST at verbosity 4.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.
      arg_values: A dict containing value hints for symbols like function
        parameters.
      arg_types: A dict containing type hints for symbols like function
        parameters.

    Returns:
      A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to
          `o`, but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: The third value returned by the converter that handled
          `o` (for functions, the namespace dict visible to the entity).

    Raises:
      NotImplementedError: if `o` is not a class, function or method but has
        a `__class__` attribute (object conversion is not supported).
      ValueError: if the entity type is not supported.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        nodes, name, entity_info = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        nodes, name, entity_info = function_to_graph(o, program_ctx,
                                                     arg_values, arg_types)
    elif hasattr(o, '__class__'):
        # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'python/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    compiler.ast_to_source(nodes))
    if logging.has_verbosity(4):
        for compiled_node in nodes:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                        pretty_printer.fmt(compiled_node, color=False))

    return nodes, name, entity_info
示例#19
0
  def _transform_function(self, fn, user_context):
    """Parses `fn`, transforms its AST and gives the result a fresh name.

    Args:
      fn: A function or lambda.
      user_context: An opaque object (may be None), stored in the
        `transformer.Context` that is passed to `transform_ast`.

    Returns:
      A tuple `(node, context)`. `node` is the transformed AST under a fresh
      name: a lambda is wrapped in a `gast.Assign` binding it to that name,
      and any other node has its `name` replaced. `context` is the
      `transformer.Context` that was used.
    """
    future_features = inspect_utils.getfutureimports(fn)
    node, source = parser.parse_entity(fn, future_features=future_features)
    logging.log(3, 'Source code of %s:\n\n%s\n', fn, source)

    origin_info.resolve_entity(node, source, fn)

    namespace = inspect_utils.getnamespace(fn)
    namer = naming.Namer(namespace)
    new_name = namer.new_symbol(self.get_transformed_name(node), ())
    info = transformer.EntityInfo(
        name=new_name,
        source_code=source,
        source_file='<fragment>',
        future_features=future_features,
        namespace=namespace)
    context = transformer.Context(info, namer, user_context)

    transformed = self.transform_ast(self._erase_arg_defaults(node), context)

    if not isinstance(transformed, gast.Lambda):
      transformed.name = new_name
      return transformed, context

    # Lambdas have no name of their own; bind the result to new_name.
    target = gast.Name(
        new_name, ctx=gast.Store(), annotation=None, type_comment=None)
    return gast.Assign(targets=[target], value=transformed), context
示例#20
0
def _log_callargs(f, args, kwargs):
    """Logs the defaults of `f` and how `args`/`kwargs` bind to its parameters.

    Args:
      f: The callable about to be called.
      args: Positional arguments of the call.
      kwargs: Keyword arguments of the call, or None if there are none.
    """
    # Not every callable defines these attributes (e.g. `functools.partial`
    # objects and builtins do not). A logging helper must not raise
    # AttributeError because of that, so missing attributes log as None.
    logging.log(2, 'Defaults of %s : %s', f, getattr(f, '__defaults__', None))
    logging.log(2, 'KW defaults of %s : %s', f,
                getattr(f, '__kwdefaults__', None))

    if kwargs is not None:
        callargs = tf_inspect.getcallargs(f, *args, **kwargs)
    else:
        callargs = tf_inspect.getcallargs(f, *args)

    formatted_callargs = '\n'.join('    {}: {}'.format(k, v)
                                   for k, v in callargs.items())
    logging.log(2, 'Calling %s with\n%s\n', f, formatted_callargs)
示例#21
0
def convert_entity_to_ast(o, program_ctx):
  """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (ast, new_name, entity_info):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * entity_info: The entity metadata returned by `convert_class_to_ast`
            or `convert_func_to_ast`, whichever handled `o`.

  Raises:
    NotImplementedError: if `o` is not a class, function or method (object
        conversion is not supported yet).
    ValueError: if `o` has no `__class__` attribute. In practice every Python
        object has one, so unsupported entities raise NotImplementedError.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    # Plain functions and methods share the same conversion path.
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'python/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(nodes))
  if logging.has_verbosity(4):
    # Dump each generated node separately.
    for n in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  return nodes, name, entity_info
示例#22
0
def convert_entity_to_ast(o, program_ctx):
  """Compile a Python entity into equivalent TensorFlow.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.

  Returns:
    A tuple (ast, new_name, entity_info):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which, when executed, creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * entity_info: The entity metadata returned by `convert_class_to_ast`
            or `convert_func_to_ast`, whichever handled `o`.

  Raises:
    NotImplementedError: if `o` is not a class, function or method (object
        conversion is not supported yet).
    ValueError: if `o` has no `__class__` attribute. In practice every Python
        object has one, so unsupported entities raise NotImplementedError.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    # Plain functions and methods share the same conversion path.
    nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
  elif hasattr(o, '__class__'):
    # Note: this should only be raised when attempting to convert the object
    # directly. converted_call should still support it.
    raise NotImplementedError(
        'cannot convert entity "{}": object conversion is not yet'
        ' supported.'.format(o))
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(nodes))
  if logging.has_verbosity(4):
    # Dump each generated node separately.
    for n in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  return nodes, name, entity_info
示例#23
0
def convert_entity_to_ast(o, program_ctx):
    """Compile a Python entity into equivalent TensorFlow.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.

    Returns:
      A tuple (ast, new_name, entity_info):
          * ast: An AST representing an entity with interface equivalent to
              `o`, but which, when executed, creates a TF graph.
          * new_name: The symbol name under which the new entity can be found.
          * entity_info: The entity metadata returned by
              `convert_class_to_ast` or `convert_func_to_ast`, whichever
              handled `o`.

    Raises:
      NotImplementedError: if `o` is not a class, function or method (object
          conversion is not supported yet).
      ValueError: if `o` has no `__class__` attribute. In practice every
          Python object has one, so unsupported entities raise
          NotImplementedError.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        nodes, name, entity_info = convert_class_to_ast(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        # Plain functions and methods share the same conversion path.
        nodes, name, entity_info = convert_func_to_ast(o, program_ctx)
    elif hasattr(o, '__class__'):
        # Note: this should only be raised when attempting to convert the object
        # directly. converted_call should still support it.
        raise NotImplementedError(
            'cannot convert entity "{}": object conversion is not yet'
            ' supported.'.format(o))
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    compiler.ast_to_source(nodes))
    if logging.has_verbosity(4):
        # Dump each generated node separately.
        for n in nodes:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                        pretty_printer.fmt(n, color=False))

    return nodes, name, entity_info
示例#24
0
def is_whitelisted_for_graph(o, check_call_override=True):
    """Checks whether an entity is whitelisted for use in graph mode.

    Examples of whitelisted entities include all members of the tensorflow
    package.

    Args:
      o: A Python entity.
      check_call_override: Reserved for internal use. When set to `False`, it
        disables the rule according to which classes are whitelisted if their
        __call__ method is whitelisted.

    Returns:
      Boolean
    """
    # TODO(b/120224672): Fix this.
    if isinstance(o, functools.partial):
        # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
        # functools.partial objects do not have a __module__ attribute.
        m = functools
    else:
        m = tf_inspect.getmodule(o)

    if hasattr(m, '__name__'):
        # Builtins typically have unnamed modules.
        # Match whole dotted components only: prefix "foo" covers "foo" and
        # "foo.bar", but not an unrelated "foobar" module.
        for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
            if m.__name__.startswith(prefix + '.') or m.__name__ == prefix:
                logging.log(2, 'Whitelisted: %s: name starts with "%s"', o,
                            prefix)
                return True

    if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
        logging.log(2, 'Whitelisted: %s: already converted', o)
        return True

    if tf_inspect.isgeneratorfunction(o):
        # The message is a %-format string: pass `o` as a lazy argument. A
        # pre-formatted message plus a stray extra positional argument (or a
        # '%' inside repr(o)) breaks the logger's formatting.
        logging.warn(
            'Entity %s appears to be a generator function. It will not be'
            ' converted by AutoGraph.', o)
        logging.log(2,
                    'Whitelisted: %s: generator functions are not converted',
                    o)
        return True

    if check_call_override and hasattr(o, '__call__'):
        # Callable objects: whitelisted if their __call__ method is.
        # The type check avoids infinite recursion around the __call__ method
        # of function objects.
        if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(
                o.__call__):  # pylint: disable=unidiomatic-typecheck
            logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
            return True

    owner_class = None
    if tf_inspect.ismethod(o):
        # Methods of whitelisted classes are also whitelisted, even if they are
        # bound via user subclasses.
        #
        # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
        # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
        # whitelisted.
        #
        #   class Custom(tf.Foo):
        #     pass
        #
        #   baz = Custom()
        #
        # For the example above, if `Custom` did overload `bar`, then it would no
        # longer be whitelisted.

        owner_class = inspect_utils.getmethodclass(o)
        if owner_class is not None:
            if issubclass(owner_class, unittest.TestCase):
                logging.log(2, 'Whitelisted: %s: method of TestCase subclass',
                            o)
                return True

            owner_class = inspect_utils.getdefiningclass(o, owner_class)
            # When checking the owner of a __call__ method, skip the owner's
            # own __call__ rule to avoid recursing back into this method.
            is_call_override = (o.__name__ == '__call__')
            if is_whitelisted_for_graph(
                    owner_class, check_call_override=not is_call_override):
                logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                            owner_class)
                return True

    if inspect_utils.isnamedtuple(o):
        # Due to the way they're constructed, namedtuple types cannot be converted
        # because they don't expose source code. But we assume they are safe for
        # graph mode since they are just containers.
        if tf_inspect.isclass(o) and len(o.__bases__) > 1:
            logging.warn(
                'Entity %s looks like a namedtuple subclass. Its constructor will'
                ' not be converted by AutoGraph, but if it has any custom methods,'
                ' those will be.', o)
        logging.log(2, 'Whitelisted: %s: named tuple', o)
        return True

    logging.log(2, 'Not whitelisted: %s: default rule', o)
    return False
示例#25
0
def converted_call(f, owner, options, args, kwargs):
    """Compiles a function call inline. For internal use only.

    Args:
      f: The callable to invoke. When `owner` is given, `f` must instead be the
        name (a str) of the attribute of `owner` to invoke.
      owner: Optional object on which `f` is looked up as an attribute.
      options: Conversion options. This function reads `force_conversion`,
        `internal_convert_user_code`, `recursive` and `optional_features`.
      args: Tuple of positional arguments for the call.
      kwargs: Optional dict of keyword arguments; None means no keywords.

    Returns:
      The result of calling `f` with the given arguments. `f` is converted by
      AutoGraph when eligible and called as-is otherwise.

    Raises:
      ValueError: if `owner` is given but `f` is not a string.
    """
    if owner is not None:
        if not isinstance(f, str):
            raise ValueError(
                'When owner is specified, the function name must be specified as'
                ' a string: {}'.format(f))
        owner_attr = f

        # Special case when the owner is a 'super' object. In that case lookups of
        # dynamic attributes won't work. See
        # inspect_utils.SuperWrapperForDynamicAttrs.
        if isinstance(owner, super):
            owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

        f = getattr(owner, f)

    if logging.has_verbosity(1):
        if owner is not None:
            composite_desc = '("{}" attr of {})'.format(owner_attr, owner)
        else:
            composite_desc = ''

        logging.log(1, 'Converted call: %s %s\n    args: %s\n    kwargs: %s\n',
                    f, composite_desc, args, kwargs)

    if inspect_utils.isbuiltin(f):
        if kwargs:
            return py_builtins.overload_of(f)(*args, **kwargs)
        else:
            return py_builtins.overload_of(f)(*args)

    # TODO(mdan): Clean up the naming inconsistency.
    if hasattr(f, 'autograph_info__') or hasattr(f, '__ag_compiled'):
        logging.log(2, 'Permanently whitelisted: %s: already converted', f)
        return _call_unconverted(f, args, kwargs)

    # TODO(b/122265385): Remove this bypass.
    if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper')
            or _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
        logging.warn(
            'Entity {} appears to be decorated by wrapt, which is not yet supported'
            ' by AutoGraph. The function will be called without transformation.'
            ' You may however apply AutoGraph before the decorator.'.format(f))
        logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
        return _call_unconverted(f, args, kwargs)

    if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
        logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
        return _call_unconverted(f, args, kwargs)

    # Constructors are permanently whitelisted.
    # TODO(mdan): Toggle as experimental feature instead.
    # TODO(b/124016764): Remove this limitation.
    if tf_inspect.isclass(f):
        logging.log(2, 'Permanently whitelisted: %s: constructor', f)
        return _call_unconverted(f, args, kwargs)

    # Other built-in modules are permanently whitelisted.
    # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
    if any(f in m.__dict__.values()
           for m in (collections, pdb, copy, inspect, re)):
        logging.log(2, 'Permanently whitelisted: %s: part of builtin module',
                    f)
        return _call_unconverted(f, args, kwargs)

    # Custom ops and kernels are also permanently whitelisted.
    # See tensorflow.framework.load_library. `f.__module__` is the module's
    # *name* (a str), which never carries the marker attribute. The module
    # object itself must be fetched from sys.modules (where load_library
    # presumably registers it) before testing for the marker.
    import sys  # pylint: disable=g-import-not-at-top
    module_name = getattr(f, '__module__', None)
    if (isinstance(module_name, str)
            and hasattr(sys.modules.get(module_name), '_IS_TENSORFLOW_PLUGIN')):
        logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f)
        return _call_unconverted(f, args, kwargs)

    if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
        return _call_unconverted(f, args, kwargs)

    # internal_convert_user_code is for example turned off when issuing a dynamic
    # call conversion from generated code while in nonrecursive mode. In that
    # case we evidently don't want to recurse, but we still have to convert
    # things like builtins.
    if not options.internal_convert_user_code:
        return _call_unconverted(f, args, kwargs)

    # Bind target_entity before the guarded block. Otherwise the error handler
    # would raise UnboundLocalError if, for example, the partial unwrapping
    # failed before a branch below assigned it.
    target_entity = f

    # TODO(mdan): Move this entire block inside to_graph.
    try:  # Begin of transformation error guards

        # Unwrap functools.partial objects
        # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
        while isinstance(f, functools.partial):
            args = f.args + args
            new_kwargs = {}
            if f.keywords is not None:
                new_kwargs.update(f.keywords)
            if kwargs is not None:
                new_kwargs.update(kwargs)
            kwargs = new_kwargs
            f = f.func

        if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
            # Regular functions
            target_entity = f
            f_self = inspect_utils.getmethodself(f)

            # TODO(b/119246461): This may be more elegantly handled using __get__?
            if f_self is not None:
                effective_args = (f_self, ) + args
            else:
                effective_args = args

        elif tf_inspect.isclass(f):
            # Constructors
            # Note: Until we support class constructurs, and enable whole-class
            # conversion with an experimental flag, this branch is dead code.
            # TODO(mdan): Consider removing unless there is a compelling use case.
            target_entity = f
            effective_args = args

        elif hasattr(f, '__call__') and hasattr(f, '__class__'):
            # Callable objects
            target_entity = f.__call__
            effective_args = (f, ) + args

        else:
            target_entity = f
            raise NotImplementedError('unknown callable type "%s"' % type(f))

        if not tf_inspect.isclass(target_entity):
            if not hasattr(target_entity, '__code__'):
                logging.log(2, 'Permanently whitelisted: %s: native binding',
                            target_entity)
                return _call_unconverted(f, args, kwargs)
            elif (hasattr(target_entity.__code__, 'co_filename')
                  and target_entity.__code__.co_filename == '<string>'):
                # TODO(mdan): __globals__['txt'] might work in Py3.
                logging.log(
                    2, 'Permanently whitelisted: %s: dynamic code (exec?)',
                    target_entity)
                return _call_unconverted(f, args, kwargs)

        converted_f = to_graph(
            target_entity,
            recursive=options.recursive,
            experimental_optional_features=options.optional_features)

        if logging.has_verbosity(2):
            logging.log(2, 'Defaults of %s : %s', converted_f,
                        converted_f.__defaults__)
            if six.PY3:
                logging.log(2, 'KW defaults of %s : %s', converted_f,
                            converted_f.__kwdefaults__)

            if kwargs is not None:
                callargs = tf_inspect.getcallargs(converted_f, *effective_args,
                                                  **kwargs)
            else:
                callargs = tf_inspect.getcallargs(converted_f, *effective_args)

            formatted_callargs = '\n'.join('    {}: {}'.format(k, v)
                                           for k, v in callargs.items())
            logging.log(2, 'Calling %s with\n%s\n', converted_f,
                        formatted_callargs)

    except Exception as e:  # pylint:disable=broad-except
        logging.log(1,
                    'Error transforming entity %s',
                    target_entity,
                    exc_info=True)
        if is_autograph_strict_conversion_mode():
            raise
        logging.warn(
            'Entity %s could not be transformed and will be executed as-is.'
            ' Please report this to the AutoGraph team. When filing the bug, set'
            ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
            ' attach the full output. Cause: %s', target_entity, e)
        return _call_unconverted(f, args, kwargs)

    # Run the converted function outside the transformation guards, so that
    # errors raised by user code propagate rather than triggering the fallback.
    with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
        try:
            if kwargs is not None:
                result = converted_f(*effective_args, **kwargs)
            else:
                result = converted_f(*effective_args)
        except Exception as e:
            _attach_metadata(e, converted_f, True)
            raise

    return result
示例#26
0
def converted_call(f,
                   args,
                   kwargs,
                   caller_fn_scope=None,
                   options=None):
  """Compiles a function call inline.

  For internal use only.

  Note: The argument list is optimized for readability of generated code, which
  may look like this:

    ag__.converted_call(f, (arg1, arg2), None, fscope)
    ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
    ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

  Args:
    f: The function to convert.
    args: Tuple, the original positional arguments of f
    kwargs: Optional[Dict], the original keyword arguments of f
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made.
    options: Optional[converter.ConversionOptions], conversion options. If not
      specified, the value of caller_fn_scope.callopts is used. Either options
      or caller_fn_scope must be present.

  Returns:
    Any, the result of executing a possibly-converted `f` with the given
      arguments.

  Raises:
    ValueError: if neither `options` nor `caller_fn_scope` is given.
  """
  logging.log(1, 'Converted call: %s\n    args: %s\n    kwargs: %s\n', f, args,
              kwargs)

  if options is None:
    if caller_fn_scope is None:
      raise ValueError('either caller_fn_scope or options must have a value')
    options = caller_fn_scope.callopts

  if conversion.is_in_whitelist_cache(f, options):
    logging.log(2, 'Whitelisted %s: from cache', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    logging.log(2, 'Whitelisted: %s: AutoGraph is disabled in context', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if is_autograph_artifact(f):
    logging.log(2, 'Permanently whitelisted: %s: AutoGraph artifact', f)
    return _call_unconverted(f, args, kwargs, options)

  # If this is a partial, unwrap it and redo all the checks.
  if isinstance(f, functools.partial):
    new_kwargs = {}
    if f.keywords is not None:
      # Use copy to avoid mutating the underlying keywords.
      new_kwargs = f.keywords.copy()
    if kwargs is not None:
      new_kwargs.update(kwargs)
    new_args = f.args + args
    logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args,
                new_kwargs)
    return converted_call(
        f.func,
        new_args,
        new_kwargs,
        caller_fn_scope=caller_fn_scope,
        options=options)

  if inspect_utils.isbuiltin(f):
    if f is eval:
      return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
    if f is super:
      return py_builtins.super_in_original_context(f, args, caller_fn_scope)
    if kwargs:
      return py_builtins.overload_of(f)(*args, **kwargs)
    else:
      return py_builtins.overload_of(f)(*args)

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
    logging.warn(
        '{} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will run as-is.'
        ' You may still apply AutoGraph before the wrapt decorator.'.format(f))
    logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
    return _call_unconverted(f, args, kwargs, options)

  if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
    return _call_unconverted(f, args, kwargs, options)

  # Constructors are permanently whitelisted.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if inspect_utils.isconstructor(f):
    logging.log(2, 'Permanently whitelisted: %s: constructor', f)
    return _call_unconverted(f, args, kwargs, options)

  # Other built-in modules are permanently whitelisted.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(
      f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)):
    logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
    return _call_unconverted(f, args, kwargs, options)

  # Custom ops and kernels are also permanently whitelisted.
  # See tensorflow.framework.load_library. `f.__module__` is the module's
  # *name* (a str), which never carries the marker attribute. The module
  # object itself must be fetched from sys.modules (where load_library
  # presumably registers it) before testing for the marker.
  import sys  # pylint: disable=g-import-not-at-top
  module_name = getattr(f, '__module__', None)
  if (isinstance(module_name, str) and
      hasattr(sys.modules.get(module_name), '_IS_TENSORFLOW_PLUGIN')):
    logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f)
    return _call_unconverted(f, args, kwargs, options)

  if not options.user_requested and conversion.is_whitelisted(f):
    return _call_unconverted(f, args, kwargs, options)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs, options)

  # Bind target_entity before the guarded block so the error handler can
  # always report it, even if a check below raises before assigning it.
  target_entity = f

  try:
    if inspect.ismethod(f) or inspect.isfunction(f):
      target_entity = f
      effective_args = args

      # Bound methods: pass the receiver explicitly as the first argument.
      f_self = getattr(f, '__self__', None)
      if f_self is not None:
        if isinstance(f_self, function.TfMethodTarget):
          f_self = f_self.target
        effective_args = (f_self,) + effective_args

    elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
      # Callable objects. Dunder methods have special lookup rules, see:
      # https://docs.python.org/3/reference/datamodel.html#specialnames
      # TODO(mdan): Recurse into converted_call to simplify other verifications.
      # This should be handled in the same way as partials.
      target_entity = f.__class__.__call__
      effective_args = (f,) + args

    else:
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  if not hasattr(target_entity, '__code__'):
    logging.log(2, 'Permanently whitelisted: %s: native binding',
                target_entity)
    return _call_unconverted(f, args, kwargs, options)
  elif (hasattr(target_entity.__code__, 'co_filename') and
        target_entity.__code__.co_filename == '<string>'):
    # TODO(mdan): __globals__['txt'] might work in Py3.
    logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)',
                target_entity)
    return _call_unconverted(f, args, kwargs, options)

  try:
    program_ctx = converter.ProgramContext(
        options=options, autograph_module=tf_inspect.getmodule(converted_call))
    converted_f = conversion.convert(target_entity, program_ctx)
    if logging.has_verbosity(2):
      _log_callargs(converted_f, effective_args, kwargs)
  except Exception as e:  # pylint:disable=broad-except
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise
    return _fall_back_unconverted(f, args, kwargs, options, e)

  # Run the converted function outside the conversion guards, so that errors
  # raised by user code propagate rather than triggering the fallback.
  with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
    try:
      if kwargs is not None:
        result = converted_f(*effective_args, **kwargs)
      else:
        result = converted_f(*effective_args)
    except Exception as e:
      _attach_metadata(e, converted_f)
      raise

  return result
示例#27
0
 def _cached_factory(self, fn, cache_subkey):
   """Returns the factory stored in the cache for `fn` under `cache_subkey`.

   Callers check presence first (e.g. via `self._cache.has`); this only
   performs the lookup and logs the hit at verbosity 3.
   """
   per_fn_entries = self._cache[fn]
   hit = per_fn_entries[cache_subkey]
   logging.log(3, 'Cache hit for %s subkey %s: %s', fn, cache_subkey, hit)
   return hit
示例#28
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

    When `program_ctx.options.recursive` is set, the function will also
    recursively compile all the entities that `o` references, updating
    `dependency_cache`.

    This function is reentrant, and relies on dependency_cache to avoid
    generating duplicate code.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.

    Returns:
      A tuple (ast, new_name, namespace):
          * ast: An AST representing an entity with interface equivalent to
              `o`, but which, when executed, creates a TF graph.
          * new_name: The symbol name under which the new entity can be found.
          * namespace: A dict mapping all symbols visible to the converted
              entity, keyed by their symbol name.

    Raises:
      NotImplementedError: if `o` is not a class, function or method (object
          conversion is not supported yet).
      ValueError: if `o` has no `__class__` attribute. In practice every
          Python object has one, so unsupported entities raise
          NotImplementedError.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        node, name, ns = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
        # Plain functions and methods share the same conversion path.
        node, name, ns = function_to_graph(o, program_ctx, arg_values,
                                           arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    elif hasattr(o, '__class__'):
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'contrib/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    # TODO(mdan): This is temporary. it should be created using a converter.
    # TODO(mdan): The attribute should be added with a helper, not directly.
    # The helper can ensure there are no collisions.
    template = '''
      entity.autograph_info__ = {}
  '''
    node.extend(templates.replace(template, entity=name))

    program_ctx.add_to_cache(o, node)

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    compiler.ast_to_source(node))

    if program_ctx.options.recursive:
        # Candidates deliberately left unconverted below. A skipped candidate
        # never enters dependency_cache, so without this set the scan would
        # select it again on every iteration and never terminate.
        skipped = set()
        while True:
            candidate = None
            for obj in program_ctx.name_map.keys():
                if (obj not in program_ctx.dependency_cache
                        and obj not in skipped):
                    candidate = obj
                    break
            if candidate is None:
                break
            if (hasattr(candidate, 'im_class') and getattr(
                    candidate, 'im_class') not in program_ctx.partial_types):
                # Class members are converted with their objects, unless they're
                # only converted partially.
                skipped.add(candidate)
                continue
            entity_to_graph(candidate, program_ctx, {}, {})

    return node, name, ns
示例#29
0
def is_whitelisted_for_graph(o):
  """Check whether an entity is whitelisted for use in graph mode.

  Examples of whitelisted entities include all members of the tensorflow
  package.

  Args:
    o: A Python entity.
  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)
  if not hasattr(m, '__name__'):
    # Note: typically it's builtins that fall in this category. Builtins will
    # be handled by specific code that follows this screening layer.
    logging.log(2, '%s is NOT whitelisted: unknown module name', o)
    return False

  # Match whole dotted components only: prefix "foo" covers "foo" and
  # "foo.bar", but not an unrelated "foobar" module.
  for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
    if m.__name__.startswith(prefix + '.') or m.__name__ == prefix:
      logging.log(2, '%s is whitelisted: name starts with "%s"', o, prefix)
      return True

  if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
    logging.log(2, '%s is whitelisted: already converted', o)
    return True

  if (not inspect_utils.isweakrefself(o) and not tf_inspect.isclass(o) and
      hasattr(o, '__call__') and hasattr(o, '__class__')):
    # Callable objects: whitelisted if their __call__ method is.
    call_whitelisted = is_whitelisted_for_graph(o.__call__)
    if call_whitelisted:
      logging.log(2, '%s is whitelisted: object __call__ whitelisted', o)
      return call_whitelisted

  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.

    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is not None:
      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_whitelisted_for_graph(owner_class):
        logging.log(2, '%s is whitelisted: owner is whitelisted %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      logging.warn_first_n(
          'Entity {} looks like a namedtuple subclass. If it has any custom'
          ' methods, they will not be converted by AutoGraph.'.format(o), 1)
    logging.log(2, '%s is whitelisted: named tuple', o)
    return True

  logging.log(2, '%s is NOT whitelisted', o)
  return False
示例#30
0
def converted_call(f, owner, options, args, kwargs):
  """Compiles a function call inline. For internal use only.

  Resolves `f` (looking it up on `owner` when an owner is given), then either
  calls it unchanged or converts it with `to_graph` and calls the converted
  version.

  `f` is called unchanged when it is:
    * a builtin (routed through `py_builtins.overload_of`);
    * wrapt-decorated;
    * a class;
    * a member of `collections`, `pdb` or `copy`;
    * whitelisted for graph mode (unless `options.force_conversion` is set).
  It is also called unchanged when `options.internal_convert_user_code` is
  off, or when conversion fails with one of the guarded exception types.

  Args:
    f: The callable to invoke or, when `owner` is not None, the name (a str)
      of the attribute of `owner` to invoke.
    owner: Optional object on which `f` is looked up.
    options: Conversion options. This function reads `force_conversion`,
      `internal_convert_user_code`, `recursive`, `optional_features`,
      `strip_decorators` and `verbose` from it.
    args: Tuple of positional arguments.
    kwargs: Dict of keyword arguments.

  Returns:
    The result of calling `f` (converted or not) with `args` and `kwargs`.

  Raises:
    ValueError: if `owner` is specified but `f` is not a string.
  """
  logging.log(1,
              'Converted call: %s; owner: %s\n    args: %s\n    kwargs: %s\n',
              f, owner, args, kwargs)

  if owner is not None:
    if not isinstance(f, str):
      raise ValueError(
          'When owner is specified, the function name must be specified as'
          ' a string: {}'.format(f))

    # Special case when the owner is a 'super' object. In that case lookups of
    # dynamic attributes won't work. See
    # inspect_utils.SuperWrapperForDynamicAttrs.
    if isinstance(owner, super):
      owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

    f = getattr(owner, f)

  if inspect_utils.isbuiltin(f):
    return py_builtins.overload_of(f)(*args, **kwargs)

  # TODO(b/122265385): Remove this bypass.
  if ('wrapt' in sys.modules and
      hasattr(sys.modules['wrapt'], 'FunctionWrapper') and
      isinstance(f, sys.modules['wrapt'].FunctionWrapper)):
    # The message is fully formatted with str.format, so no extra positional
    # arguments may be passed: the logger would apply `msg % args` to them.
    logging.warn(
        'Entity {} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will be called without transformation.'
        ' You may however apply AutoGraph before the decorator.'.format(f))
    logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
    return f(*args, **kwargs)

  # Constructors are permanently whitelisted.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if tf_inspect.isclass(f):
    logging.log(2, 'Permanently whitelisted: %s: constructor', f)
    return f(*args, **kwargs)

  # Other built-in modules are permanently whitelisted.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if (f in collections.__dict__.values() or f in pdb.__dict__.values() or
      f in copy.__dict__.values()):
    logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
    return f(*args, **kwargs)

  # TODO(mdan): This needs cleanup.
  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):

    # TODO(mdan): This may be inconsistent in certain situations.
    # If the function had already been annotated with @tf.function, it
    # may be bound to the incorrect object. It's unclear if those situations
    # are possible, but if they happen, we need to check if f is bound
    # to a shim like WeakrefSelf and unpack it.

    # Args typically include `self`, as required by the conversion process.
    # When conversion is skipped, `self` is not necessary, because the
    # original bound method is being executed. This code removes it.
    if tf_inspect.ismethod(f) and args:
      f_self = inspect_utils.getmethodself(f)
      if args[0] is f_self:
        args = args[1:]

    return f(*args, **kwargs)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return f(*args, **kwargs)

  # Bind target_entity before the guarded block so the error handler can always
  # name the entity. Without this, a failure before the callable type is
  # resolved (e.g. while unwrapping a partial, or for an unsupported callable
  # type) would raise UnboundLocalError and hide the original error.
  target_entity = f

  # TODO(mdan): Move this entire block inside to_graph.
  try:  # Begin of transformation error guards

    # Unwrap functools.partial objects
    # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
    while isinstance(f, functools.partial):
      args = f.args + args
      new_kwargs = {}
      if f.keywords is not None:
        new_kwargs.update(f.keywords)
      new_kwargs.update(kwargs)
      kwargs = new_kwargs
      f = f.func

    if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
      # Regular functions
      target_entity = f
      arg_map_target = f
      f_self = inspect_utils.getmethodself(f)

      # TODO(b/119246461): This may be more elegantly handled using __get__?
      if f_self is not None:
        # If this is a method call, it may or may not include self.
        #
        # Example when self is included:
        #   converted_call(to_graph(foo.bar), foo)
        #
        # Example when self is not included:
        #   super(...).foo(args)
        #
        if owner is not None and (not args or args[0] is not owner):
          effective_args = (owner,) + args
        else:
          # When the owner is not specified, use the result of
          # inspect_utils.getmethodclass.
          # TODO(b/119246461): Make sure an owner is always specified.
          if not args or args[0] is not f_self:
            effective_args = (f_self,) + args
          else:
            effective_args = (f_self,) + args[1:]
        partial_types = (f_self,)
      else:
        effective_args = args
        partial_types = ()

    elif tf_inspect.isclass(f):
      # Constructors
      # Note: Until we support class constructurs, and enable whole-class
      # conversion with an experimental flag, this branch is dead code.
      # TODO(mdan): Consider removing unless there is a compelling use case.
      target_entity = f
      arg_map_target = f.__init__
      effective_args = args
      partial_types = ()

    elif hasattr(f, '__call__') and hasattr(f, '__class__'):
      # Callable objects
      target_entity = f.__call__
      arg_map_target = f.__call__
      effective_args = (f,) + args
      partial_types = (f.__class__,)

    else:
      raise NotImplementedError('unknown callable type "%s"' % type(f))

    arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
    arg_types = {}
    for name, arg in arg_values.items():
      arg_class = arg.__class__
      arg_types[name] = (arg_class.__name__, arg_class)

    # When called from within a decorator, this is the only indication that
    # the function is a method - it appears that the decorator is applied
    # before the method is bound.
    if not partial_types:
      if 'self' in arg_values:
        if tf_inspect.isclass(arg_values['self'].__class__):
          partial_types = (arg_values['self'].__class__,)
      elif 'cls' in arg_values:
        if tf_inspect.isclass(arg_values['cls']):
          partial_types = (arg_values['cls'],)

    logging.log(3, 'Partial types in conversion of %s: %s', target_entity,
                partial_types)

    converted_f = to_graph(
        target_entity,
        recursive=options.recursive,
        arg_values=arg_values,
        arg_types=arg_types,
        experimental_optional_features=options.optional_features,
        experimental_strip_decorators=options.strip_decorators,
        experimental_verbose=options.verbose,
        experimental_partial_types=partial_types)

    if logging.has_verbosity(2):
      logging.log(2, 'Defaults of %s : %s', converted_f,
                  converted_f.__defaults__)
      callargs = tf_inspect.getcallargs(converted_f, *effective_args, **kwargs)
      formatted_callargs = '\n'.join(
          '    {}: {}'.format(k, v) for k, v in callargs.items())
      logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)

  # TODO(mdan): Reduce this list.
  except (errors.AutoGraphError, AssertionError, AttributeError, IndexError,
          KeyError, NameError, NotImplementedError, SyntaxError, TypeError,
          ValueError, IOError) as e:
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    logging.warn(
        'Entity %s could not be transformed and will be staged without change.'
        ' Error details can be found in the logs when running with the env'
        ' variable AUTOGRAPH_VERBOSITY=5. Please report this to the AutoGraph'
        ' team. Cause: %s', target_entity, e)

    return f(*args, **kwargs)

  result = converted_f(*effective_args, **kwargs)

  # The converted function's closure is simply inserted into the function's
  # module __dict__. Since modules are permanently cached, that results in
  # leaking the entire closure.
  # Normally, it's not safe to delete the module because that may release said
  # closure as well. However, in the case of converted_call we are certain the
  # function will not be executed again, so the closure should no longer be
  # needed so long as the function doesn't return any executable code.
  # TODO(mdan): Attach the closure properly, using cells.
  if all(map(_is_not_callable, nest.flatten(result))):
    del sys.modules[converted_f.__module__]

  return result
示例#31
0
def is_whitelisted(o,
                   check_call_override=True,
                   allow_namedtuple_subclass=False):
    """Decides whether `o` may run in graph mode without AutoGraph conversion.

    Typical whitelisted entities are the members of the tensorflow package.
    The checks below run in order, and the first one that matches decides:

      1. the module-level rules in `config.CONVERSION_RULES`;
      2. generator functions;
      3. callable objects whose `__call__` is whitelisted;
      4. methods whose owner class is a `unittest.TestCase` subclass, or whose
         defining class is whitelisted;
      5. namedtuples.

    Anything else is not whitelisted.

    Args:
      o: A Python entity.
      check_call_override: Reserved for internal use. When `False`, the rule
        that whitelists an object because its `__call__` method is
        whitelisted is skipped.
      allow_namedtuple_subclass: Reserved for internal use. When `True`,
        subclasses of namedtuples are not whitelisted; only namedtuples none
        of whose bases is itself a namedtuple are.

    Returns:
      Boolean
    """
    # TODO(b/120224672): Fix this.
    # functools.partial objects have no __module__ attribute, so
    # tf_inspect.getmodule would give None for them; attribute them to
    # functools explicitly instead.
    module = (functools if isinstance(o, functools.partial)
              else tf_inspect.getmodule(o))

    # Some callables (builtins, for instance) come without a module name.
    if hasattr(module, '__name__'):
        for rule in config.CONVERSION_RULES:
            action = rule.get_action(module)
            if action == config.Action.CONVERT:
                logging.log(2, 'Not whitelisted: %s: %s', o, rule)
                return False
            if action == config.Action.DO_NOT_CONVERT:
                logging.log(2, 'Whitelisted: %s: %s', o, rule)
                return True

    if tf_inspect.isgeneratorfunction(o):
        logging.warn(
            'Entity %s appears to be a generator function. It will not be converted'
            ' by AutoGraph.', o)
        logging.log(2,
                    'Whitelisted: %s: generator functions are not converted',
                    o)
        return True

    if check_call_override and not tf_inspect.isclass(o) and hasattr(
            o, '__call__'):
        # A callable instance inherits the verdict of its __call__ method.
        # Requiring the types to differ stops the recursion that would
        # otherwise follow __call__ of __call__ of a function object forever.
        if (type(o) != type(o.__call__)) and is_whitelisted(o.__call__):  # pylint: disable=unidiomatic-typecheck
            logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
            return True

    if tf_inspect.ismethod(o):
        # A method is whitelisted when the class that defines it is, even if
        # it is reached through a user subclass. Say `tf.Foo` is whitelisted
        # and defines `bar`:
        #
        #   class Custom(tf.Foo):
        #     pass
        #
        #   baz = Custom()
        #
        # Then `baz.bar` is whitelisted too. Had `Custom` overridden `bar`,
        # the defining class would be `Custom` and the rule would not apply.
        owner_class = inspect_utils.getmethodclass(o)
        if owner_class is function.TfMethodTarget:
            owner_class = o.__self__.target_class
        if owner_class is not None:
            if issubclass(owner_class, unittest.TestCase):
                logging.log(2, 'Whitelisted: %s: method of TestCase subclass',
                            o)
                return True

            defining_class = inspect_utils.getdefiningclass(o, owner_class)
            owner_whitelisted = is_whitelisted(
                defining_class,
                check_call_override=False,
                allow_namedtuple_subclass=True)
            if owner_whitelisted:
                logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                            defining_class)
                return True

    if inspect_utils.isnamedtuple(o):
        # namedtuple types expose no source code and so cannot be converted,
        # but as plain containers they are assumed safe for graph mode.
        if not allow_namedtuple_subclass:
            logging.log(2, 'Whitelisted: %s: named tuple or subclass', o)
            return True
        derives_from_namedtuple = any(
            inspect_utils.isnamedtuple(base) for base in o.__bases__)
        if not derives_from_namedtuple:
            logging.log(2, 'Whitelisted: %s: named tuple', o)
            return True

    logging.log(2, 'Not whitelisted: %s: default rule', o)
    return False
示例#32
0
def create_source_map(nodes, code, filepath):
  """Maps lines of generated code back to where that code originated.

  The generated `code` is parsed again and origin information is resolved on
  the resulting AST. That AST is then walked in lockstep with `nodes`. For
  every pair of nodes that both carry an origin annotation, the line of the
  reparsed node is mapped to the origin recorded on the annotated node.

  Note: this function assumes `nodes`, `code` and `filepath` describe the same
  code.

  Args:
    nodes: Iterable[ast.AST, ...], one or more annotated AST nodes.
    code: Text, the source code in which nodes are found.
    filepath: Text

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.

  Raises:
    ValueError: re-raised when one is raised while walking the two ASTs. At
      verbosity 3 or higher, a diff of the original and reparsed nodes is
      logged first.
  """
  reparsed_nodes = parser.parse_str(code, preamble_len=0, single_node=False)
  for reparsed in reparsed_nodes:
    resolve(reparsed, code, filepath, reparsed.lineno, reparsed.col_offset)

  source_map = {}

  try:
    for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Generated code does not necessarily carry an origin annotation.
      # TODO(mdan): Generated code should always be mapped to something.
      origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
      final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
      if origin_info is None or final_info is None:
        continue

      # Keys identify a line only; the column offset is not part of them.
      line_loc = LineLocation(final_info.loc.filename, final_info.loc.lineno)

      existing_origin = source_map.get(line_loc)
      if existing_origin is not None:
        # Overlaps are expected because child nodes share lines, but they
        # almost never point at different origin lines. Decorated functions
        # are the exception: both lines are mapped to the same line in the
        # AST.
        #
        # The existing entry wins if it comes from the same origin line with
        # a lineno at least as large (keep the bottom node), or if its column
        # offset is not greater than the new one (keep the leftmost node).
        same_origin_line = (
            existing_origin.loc.line_loc == origin_info.loc.line_loc)
        keep_existing = (
            (same_origin_line and
             existing_origin.loc.lineno >= origin_info.loc.lineno) or
            existing_origin.loc.col_offset <= origin_info.loc.col_offset)
        if keep_existing:
          continue

      source_map[line_loc] = origin_info

  except ValueError:
    if logging.has_verbosity(3):
      for original_node, reparsed_node in zip(nodes, reparsed_nodes):
        original_text = pretty_printer.fmt(
            original_node, color=False, noanno=True)
        reparsed_text = pretty_printer.fmt(
            reparsed_node, color=False, noanno=True)
        diff_text = '\n'.join(
            difflib.context_diff(
                original_text.split('\n'),
                reparsed_text.split('\n'),
                fromfile='Original nodes',
                tofile='Reparsed nodes',
                n=7))
        logging.log(3, 'AST seems to lack integrity. Diff:\n%s', diff_text)
    raise

  return source_map
示例#33
0
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL,
             experimental_strip_decorators=None,
             experimental_verbose=converter.Verbosity.BRIEF,
             experimental_partial_types=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_strip_decorators: An iterable (typically a tuple) of
      decorators that should be excluded from the compiled output. By default,
      when converting a function before the decorators are applied, the
      compiled output will include those decorators. The argument is never
      modified; AutoGraph's own decorators are appended to a new tuple.
    experimental_verbose: The level of printing verbosity to use, as a
      `tf.autograph.experimental.Verbosity` value.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted.
  """
  try:
    # Build a new tuple instead of using `+=`. With `+=`, a list passed by the
    # caller (for example one shared through conversion options) would be
    # extended in place, and AutoGraph's decorators would pile up across calls.
    if experimental_strip_decorators is None:
      experimental_strip_decorators = ()
    experimental_strip_decorators = tuple(experimental_strip_decorators) + (
        convert, do_not_convert, converted_call)

    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            verbose=experimental_verbose,
            strip_decorators=experimental_strip_decorators,
            optional_features=experimental_optional_features),
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                    arg_values, arg_types)

    nodes = []
    for dep in reversed(program_ctx.conversion_order):
      nodes.extend(program_ctx.dependency_cache[dep])

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
      # Avoid overwriting entities that have been transformed.
      if key not in compiled_module.__dict__:
        compiled_module.__dict__[key] = val
    for key, val in program_ctx.additional_symbols.items():
      if key not in compiled_module.__dict__:
        compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    if hasattr(entity, '__defaults__'):
      logging.log(3, 'Default args mapping: %s has: %s', entity,
                  entity.__defaults__)
      compiled.__defaults__ = entity.__defaults__
    else:
      logging.log(3, 'Default args mapping: %s has no __defaults__', entity)

    logging.log(3, 'Namespace of %s includes: %s', compiled,
                compiled_module.__dict__.keys())

    if hasattr(compiled, '__globals__'):
      # Remove self to avoid circular references. This will probably only work
      # so long as the function is not reentrant.
      del compiled.__globals__[name]

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
      # TODO(znado): change input problem errors into TransformError
      raise ValueError('cannot convert %s because it has an attribute '
                       '"%s", which is reserved for AutoGraph.' %
                       (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    return compiled
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    errors.report_internal_error(entity, e)
示例#34
0
def converted_call(f, owner, options, args, kwargs):
  """Compiles a function call inline. For internal use only.

  Resolves `f` (looking it up on `owner` when an owner is given), then either
  calls it unchanged or converts it with `to_graph` and calls the converted
  version.

  `f` is called unchanged (through `_call_unconverted`) when it is:
    * wrapt-decorated;
    * a class;
    * a member of `collections`, `pdb`, `copy` or `inspect`;
    * whitelisted for graph mode (unless `options.force_conversion` is set).
  It is also called unchanged when `options.internal_convert_user_code` is
  off, or when conversion fails with one of the guarded exception types.
  Builtins are instead routed through `py_builtins.overload_of`.

  Args:
    f: The callable to invoke or, when `owner` is not None, the name (a str)
      of the attribute of `owner` to invoke.
    owner: Optional object on which `f` is looked up.
    options: Conversion options. This function reads `force_conversion`,
      `internal_convert_user_code`, `recursive`, `optional_features`,
      `strip_decorators` and `verbose` from it.
    args: Tuple of positional arguments.
    kwargs: Dict of keyword arguments.

  Returns:
    The result of calling `f` (converted or not) with `args` and `kwargs`.

  Raises:
    ValueError: if `owner` is specified but `f` is not a string.
  """
  logging.log(1,
              'Converted call: %s; owner: %s\n    args: %s\n    kwargs: %s\n',
              f, owner, args, kwargs)

  if owner is not None:
    if not isinstance(f, str):
      raise ValueError(
          'When owner is specified, the function name must be specified as'
          ' a string: {}'.format(f))

    # Special case when the owner is a 'super' object. In that case lookups of
    # dynamic attributes won't work. See
    # inspect_utils.SuperWrapperForDynamicAttrs.
    if isinstance(owner, super):
      owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

    f = getattr(owner, f)

  if inspect_utils.isbuiltin(f):
    return py_builtins.overload_of(f)(*args, **kwargs)

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
    logging.warn(
        'Entity {} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will be called without transformation.'
        ' You may however apply AutoGraph before the decorator.'.format(f))
    logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
    return _call_unconverted(f, args, kwargs)

  # Constructors are permanently whitelisted.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if tf_inspect.isclass(f):
    logging.log(2, 'Permanently whitelisted: %s: constructor', f)
    return _call_unconverted(f, args, kwargs)

  # Other built-in modules are permanently whitelisted.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  # Note: TF linter disallows importing inspect.
  if any(f in m.__dict__.values()
         for m in (collections, pdb, copy, tf_inspect._inspect)):  # pylint:disable=protected-access
    logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
    return _call_unconverted(f, args, kwargs)

  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
    return _call_unconverted(f, args, kwargs)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs)

  # Bind target_entity before the guarded block so the error handler can always
  # name the entity. Without this, a failure before the callable type is
  # resolved (e.g. while unwrapping a partial, or for an unsupported callable
  # type) would raise UnboundLocalError and hide the original error.
  target_entity = f

  # TODO(mdan): Move this entire block inside to_graph.
  try:  # Begin of transformation error guards

    # Unwrap functools.partial objects
    # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
    while isinstance(f, functools.partial):
      args = f.args + args
      new_kwargs = {}
      if f.keywords is not None:
        new_kwargs.update(f.keywords)
      new_kwargs.update(kwargs)
      kwargs = new_kwargs
      f = f.func

    if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
      # Regular functions
      target_entity = f
      arg_map_target = f
      f_self = inspect_utils.getmethodself(f)

      # TODO(b/119246461): This may be more elegantly handled using __get__?
      if f_self is not None:
        # If this is a method call, it may or may not include self.
        #
        # Example when self is included:
        #   converted_call(to_graph(foo.bar), foo)
        #
        # Example when self is not included:
        #   super(...).foo(args)
        #
        if owner is not None and (not args or args[0] is not owner):
          effective_args = (owner,) + args
        else:
          # When the owner is not specified, use the result of
          # inspect_utils.getmethodclass.
          # TODO(b/119246461): Make sure an owner is always specified.
          if not args or args[0] is not f_self:
            effective_args = (f_self,) + args
          else:
            effective_args = (f_self,) + args[1:]
        partial_types = (f_self,)
      else:
        effective_args = args
        partial_types = ()

    elif tf_inspect.isclass(f):
      # Constructors
      # Note: Until we support class constructurs, and enable whole-class
      # conversion with an experimental flag, this branch is dead code.
      # TODO(mdan): Consider removing unless there is a compelling use case.
      target_entity = f
      arg_map_target = f.__init__
      effective_args = args
      partial_types = ()

    elif hasattr(f, '__call__') and hasattr(f, '__class__'):
      # Callable objects
      target_entity = f.__call__
      arg_map_target = f.__call__
      effective_args = (f,) + args
      partial_types = (f.__class__,)

    else:
      raise NotImplementedError('unknown callable type "%s"' % type(f))

    arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
    arg_types = {}
    for name, arg in arg_values.items():
      arg_class = arg.__class__
      arg_types[name] = (arg_class.__name__, arg_class)

    # When called from within a decorator, this is the only indication that
    # the function is a method - it appears that the decorator is applied
    # before the method is bound.
    if not partial_types:
      if 'self' in arg_values:
        if tf_inspect.isclass(arg_values['self'].__class__):
          partial_types = (arg_values['self'].__class__,)
      elif 'cls' in arg_values:
        if tf_inspect.isclass(arg_values['cls']):
          partial_types = (arg_values['cls'],)

    logging.log(3, 'Partial types in conversion of %s: %s', target_entity,
                partial_types)

    converted_f = to_graph(
        target_entity,
        recursive=options.recursive,
        arg_values=arg_values,
        arg_types=arg_types,
        experimental_optional_features=options.optional_features,
        experimental_strip_decorators=options.strip_decorators,
        experimental_verbose=options.verbose,
        experimental_partial_types=partial_types)

    if logging.has_verbosity(2):
      logging.log(2, 'Defaults of %s : %s', converted_f,
                  converted_f.__defaults__)
      callargs = tf_inspect.getcallargs(converted_f, *effective_args, **kwargs)
      formatted_callargs = '\n'.join(
          '    {}: {}'.format(k, v) for k, v in callargs.items())
      logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)

  # TODO(mdan): Reduce this list.
  except (errors.AutoGraphError, AssertionError, AttributeError, IndexError,
          KeyError, NameError, NotImplementedError, SyntaxError, TypeError,
          ValueError, IOError) as e:
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    logging.warn(
        'Entity %s could not be transformed and will be staged without change.'
        ' Error details can be found in the logs when running with the env'
        ' variable AUTOGRAPH_VERBOSITY >= 1. Please report this to the'
        ' AutoGraph team. Cause: %s', target_entity, e)

    return _call_unconverted(f, args, kwargs)

  result = converted_f(*effective_args, **kwargs)

  # The converted function's closure is simply inserted into the function's
  # module __dict__. Since modules are permanently cached, that results in
  # leaking the entire closure.
  # Normally, it's not safe to delete the module because that may release said
  # closure as well. However, in the case of converted_call we are certain the
  # function will not be executed again, so the closure should no longer be
  # needed so long as the function doesn't return any executable code.
  # TODO(mdan): Attach the closure properly, using cells.
  if all(map(_is_not_callable, nest.flatten(result))):
    del sys.modules[converted_f.__module__]

  return result
def create_source_map(nodes, code, filename, indices_in_code):
    """Creates a source map between an annotated AST and the code it compiles to.

    Args:
      nodes: Iterable[ast.AST, ...]
      code: Text
      filename: Optional[Text]
      indices_in_code: Union[int, Iterable[int, ...]], the positions at which
          nodes appear in code. The parser always returns a module when parsing
          code. This argument indicates the position in that module's body at
          which each corresponding node should appear. A single int is treated
          as a one-element sequence.

    Returns:
      Dict[LineLocation, OriginInfo], mapping locations in code to locations
      indicated by origin annotations in node.

    Raises:
      ValueError: re-raised when one is raised while walking the original and
        reparsed ASTs. At verbosity 3 or higher, a diff of the two is logged
        first.
    """
    # The documented contract accepts a bare int, but iterating over an int
    # raises TypeError, so normalize it to a one-element tuple.
    if isinstance(indices_in_code, int):
        indices_in_code = (indices_in_code,)

    reparsed_nodes = parser.parse_str(code)
    reparsed_nodes = [reparsed_nodes.body[i] for i in indices_in_code]
    for node in reparsed_nodes:
        resolve(node, code)

    result = {}

    try:
        for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
            # Note: generated code might not be mapped back to its origin.
            # TODO(mdan): Generated code should always be mapped to something.
            origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
            final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
            if origin_info is None or final_info is None:
                continue

            line_loc = LineLocation(filename, final_info.loc.lineno)

            existing_origin = result.get(line_loc)
            if existing_origin is not None:
                # Overlaps are expected because child nodes share lines, but
                # they almost never point at different line locations.
                # Decorated functions are the exception: both lines are
                # mapped to the same line in the AST.

                # Line overlaps: keep bottom node.
                if existing_origin.loc.line_loc == origin_info.loc.line_loc:
                    if existing_origin.loc.lineno >= origin_info.loc.lineno:
                        continue

                # In case of overlaps, keep the leftmost node.
                if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
                    continue

            result[line_loc] = origin_info
    except ValueError:
        if logging.has_verbosity(3):
            for n, rn in zip(nodes, reparsed_nodes):
                nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
                reparsed_nodes_str = pretty_printer.fmt(rn,
                                                        color=False,
                                                        noanno=True)
                diff = difflib.context_diff(nodes_str.split('\n'),
                                            reparsed_nodes_str.split('\n'),
                                            fromfile='Original nodes',
                                            tofile='Reparsed nodes',
                                            n=7)
                diff = '\n'.join(diff)
                logging.log(3, 'AST seems to lack integrity. Diff:\n%s', diff)
        raise

    return result
示例#36
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Compile a Python entity into equivalent TensorFlow.

  The function will also recursively compile all the entities that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which when executed creates a TF graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    NotImplementedError: if `o` is neither a class, a function nor a method.
    ValueError: if the entity type is not supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    node, name, ns = class_to_graph(o, program_ctx)
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  # NOTE(review): every Python 3 object has `__class__`, so this branch catches
  # all remaining entities and the ValueError below is effectively unreachable.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'python/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  # TODO(mdan): This is temporary. it should be created using a converter.
  # TODO(mdan): The attribute should be added with a helper, not directly.
  # The helper can ensure there are no collisions.
  template = '''
      entity.autograph_info__ = {}
  '''
  node.extend(templates.replace(template, entity=name))

  program_ctx.add_to_cache(o, node)

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(node))
  if logging.has_verbosity(4):
    for n in node:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  if program_ctx.options.recursive:
    # Convert every referenced entity that is not yet in the dependency cache.
    # Each recursive call adds its entity to the cache, and may add new names
    # to name_map, so the search restarts after each conversion.
    while True:
      candidate = None
      for obj in program_ctx.name_map.keys():
        if obj in program_ctx.dependency_cache:
          continue
        if (hasattr(obj, 'im_class') and
            getattr(obj, 'im_class') not in program_ctx.partial_types):
          # Class members are converted with their objects, unless they're
          # only converted partially. They must be skipped during the search
          # itself: selecting one and then restarting the outer loop would
          # select the very same candidate again, looping forever.
          continue
        candidate = obj
        break
      if candidate is None:
        break
      entity_to_graph(candidate, program_ctx, {}, {})

  return node, name, ns
示例#37
0
def is_whitelisted_for_graph(o):
  """Checks whether an entity is whitelisted for use in graph mode.

  Whitelisted entities are used as-is instead of being converted. Examples of
  whitelisted entities include all members of the tensorflow package.

  Args:
    o: A Python entity.

  Returns:
    Boolean
  """
  # TODO(b/120224672): Fix this.
  if isinstance(o, functools.partial):
    # tf_inspect.getmodule(functools.partial(...)) otherwise returns None since
    # functools.partial objects do not have a __module__ attribute.
    m = functools
  else:
    m = tf_inspect.getmodule(o)

  if hasattr(m, '__name__'):
    # Builtins typically have unnamed modules.
    for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
      if m.__name__.startswith(prefix):
        logging.log(2, 'Whitelisted: %s: name starts with "%s"', o, prefix)
        return True

    # Temporary -- whitelist tensorboard modules.
    # TODO(b/122731813): Remove.
    if m.__name__ == 'tensorboard' or '.tensorboard' in m.__name__:
      logging.log(2, 'Whitelisted: %s: name contains "tensorboard"', o)
      return True

  if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
    logging.log(2, 'Whitelisted: %s: already converted', o)
    return True

  if tf_inspect.isgeneratorfunction(o):
    # The entity is passed as a lazy %-argument: formatting it into the
    # message eagerly and passing an extra positional value would make the
    # logger's %-formatting fail (and would break on a '%' in the repr).
    logging.warn(
        'Entity %s appears to be a generator function. It will not be converted'
        ' by AutoGraph.', o)
    logging.log(2, 'Whitelisted: %s: generator functions are not converted', o)
    return True

  if hasattr(o, '__call__'):
    # Callable objects: whitelisted if their __call__ method is.
    # The type check avoids infinite recursion around the __call__ method
    # of function objects.
    if (type(o) != type(o.__call__)) and is_whitelisted_for_graph(o.__call__):  # pylint: disable=unidiomatic-typecheck
      logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o)
      return True

  if tf_inspect.ismethod(o):
    # Methods of whitelisted classes are also whitelisted, even if they are
    # bound via user subclasses.
    #
    # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is
    # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also
    # whitelisted.
    #
    #   class Custom(tf.Foo):
    #     pass
    #
    #   baz = Custom()
    #
    # For the example above, if `Custom` did overload `bar`, then it would no
    # longer be whitelisted.

    owner_class = inspect_utils.getmethodclass(o)
    if owner_class is not None:
      if issubclass(owner_class, unittest.TestCase):
        logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o)
        return True

      owner_class = inspect_utils.getdefiningclass(o, owner_class)
      if is_whitelisted_for_graph(owner_class):
        logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o,
                    owner_class)
        return True

  if inspect_utils.isnamedtuple(o):
    # Due to the way they're constructed, namedtuple types cannot be converted
    # because they don't expose source code. But we assume they are safe for
    # graph mode since they are just containers.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
      logging.warn(
          'Entity %s looks like a namedtuple subclass. Its constructor will'
          ' not be converted by AutoGraph, but if it has any custom methods,'
          ' those will be.', o)
    logging.log(2, 'Whitelisted: %s: named tuple', o)
    return True

  logging.log(2, 'Not whitelisted: %s: default rule', o)
  return False
示例#38
0
def create_source_map(nodes, code, filename):
  """Maps lines of generated code back to the source they originated from.

  The generated `code` is parsed again and walked in lockstep with the
  annotated `nodes`. Every pair of nodes in which both sides carry an origin
  annotation contributes one entry, keyed by the line of the generated code.
  When several nodes land on the same generated line, an existing entry is
  kept if it comes from the same origin line, otherwise the node with the
  smaller column offset wins.

  Args:
    nodes: Iterable[ast.AST, ...], the annotated AST that `code` came from.
    code: Text, the generated code.
    filename: Optional[Text], used in the keys of the returned map.

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.

  Raises:
    ValueError: re-raised from the lockstep walk, presumably when the two ASTs
      do not match structurally. At logging verbosity >= 3 a diff of the
      original and reparsed ASTs is logged first.
  """
  reparsed_nodes = parser.parse_str(code, preamble_len=0, single_node=False)
  for reparsed in reparsed_nodes:
    resolve(reparsed, code)

  source_map = {}

  try:
    for original, generated in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Some generated code has no origin to map back to.
      # TODO(mdan): Generated code should always be mapped to something.
      origin_info = anno.getanno(original, anno.Basic.ORIGIN, default=None)
      final_info = anno.getanno(generated, anno.Basic.ORIGIN, default=None)
      if origin_info is None or final_info is None:
        continue

      line_loc = LineLocation(filename, final_info.loc.lineno)

      existing = source_map.get(line_loc)
      if existing is not None:
        # Child nodes usually overlap on the same line; decorated functions
        # are the exception, mapping both lines to one line of the AST.
        # Same origin line: keep the bottom node already recorded.
        if (existing.loc.line_loc == origin_info.loc.line_loc and
            existing.loc.lineno >= origin_info.loc.lineno):
          continue
        # Otherwise keep whichever node sits furthest to the left.
        if existing.loc.col_offset <= origin_info.loc.col_offset:
          continue

      source_map[line_loc] = origin_info
  except ValueError:
    if logging.has_verbosity(3):
      for original, reparsed in zip(nodes, reparsed_nodes):
        diff_lines = difflib.context_diff(
            pretty_printer.fmt(original, color=False, noanno=True).split('\n'),
            pretty_printer.fmt(reparsed, color=False, noanno=True).split('\n'),
            fromfile='Original nodes',
            tofile='Reparsed nodes',
            n=7)
        logging.log(3, 'AST seems to lack integrity. Diff:\n%s',
                    '\n'.join(diff_lines))
    raise

  return source_map
示例#39
0
def to_graph(entity,
             recursive=True,
             arg_values=None,
             arg_types=None,
             experimental_optional_features=converter.Feature.ALL,
             experimental_strip_decorators=None,
             experimental_verbose=converter.Verbosity.BRIEF,
             experimental_partial_types=None):
  """Converts a Python entity into a TensorFlow graph.

  Also see: `tf.autograph.to_code`, `tf.function`.

  Unlike `tf.function`, `to_graph` is a low-level transpiler that converts
  Python code to TensorFlow graph code. It does not implement any caching,
  variable management or create any actual ops, and is best used where greater
  control over the generated TensorFlow graph is desired. Another difference
  from `tf.function` is that `to_graph` will not wrap the graph into a
  TensorFlow function or a Python callable. Internally, `tf.function` uses
  `to_graph`.

  _Example Usage_

  ```python
    def foo(x):
      if x > 0:
        y = x * x
      else:
        y = -x
      return y

    converted_foo = to_graph(foo)

    x = tf.constant(1)
    y = converted_foo(x)  # converted_foo is a TensorFlow Op-like.
    assert is_tensor(y)
  ```

  Supported Python entities include:
    * functions
    * classes
    * object methods

  Functions are converted into new functions with converted code.

  Classes are converted by generating a new class whose methods use converted
  code.

  Methods are converted into unbound functions that have an additional first
  argument called `self`.

  Args:
    entity: Python callable or class to convert.
    recursive: Whether to recursively convert any functions that the
      converted function may call.
    arg_values: Optional dict of value hints for symbols including
      function arguments mapping string names to actual values. For example,
      `arg_values={'a': 1}` will map the variable `a` to the value `1`.
    arg_types: Optional dict of type hints for symbols including function
      arguments. Type hints allow specifying just the type of a variable, rather
      than a specific value.
    experimental_optional_features: `None`, a tuple of, or a single
      `tf.autograph.experimental.Feature` value. Controls the use of
      optional features in the conversion process.
    experimental_strip_decorators: A tuple (or other iterable) specifying
      decorators that should be excluded from the compiled output. By default,
      when converting a function before the decorators are applied, the
      compiled output will include those decorators. AutoGraph's own
      `convert`, `do_not_convert` and `converted_call` are always stripped.
      The argument itself is not modified.
    experimental_verbose: The level of printing verbosity to use, as a
      `tf.autograph.experimental.Verbosity` value.
    experimental_partial_types: A `set` of `type` values, reserved for internal
      use.

  Returns:
    Same as `entity`, the converted Python function or class.

  Raises:
    ValueError: If the entity could not be converted. Errors of the types
      listed in the final `except` clause are handed to
      `errors.report_internal_error`.
  """
  try:
    # Build a new tuple rather than using `+=` on the argument: when a caller
    # passes a list, an in-place `+=` would extend that caller's list with
    # AutoGraph's decorators, growing it on every call.
    if experimental_strip_decorators is None:
      experimental_strip_decorators = ()
    experimental_strip_decorators = tuple(experimental_strip_decorators) + (
        convert, do_not_convert, converted_call)

    program_ctx = converter.ProgramContext(
        options=converter.ConversionOptions(
            recursive=recursive,
            verbose=experimental_verbose,
            strip_decorators=experimental_strip_decorators,
            optional_features=experimental_optional_features),
        partial_types=experimental_partial_types,
        autograph_module=tf_inspect.getmodule(to_graph),
        uncompiled_modules=config.DEFAULT_UNCOMPILED_MODULES)
    _, name, namespace = conversion.entity_to_graph(entity, program_ctx,
                                                    arg_values, arg_types)

    nodes = []
    for dep in reversed(program_ctx.conversion_order):
      nodes.extend(program_ctx.dependency_cache[dep])

    compiled_module, _ = compiler.ast_to_object(
        nodes,
        source_prefix=program_ctx.required_imports,
        include_source_map=True)

    # The compiled code should see everything the entry entity saw.
    # TODO(mdan): This might not work well if the call tree spans modules?
    for key, val in namespace.items():
      # Avoid overwriting entities that have been transformed.
      if key not in compiled_module.__dict__:
        compiled_module.__dict__[key] = val
    for key, val in program_ctx.additional_symbols.items():
      if key not in compiled_module.__dict__:
        compiled_module.__dict__[key] = val
    compiled = getattr(compiled_module, name)

    if hasattr(entity, '__defaults__'):
      logging.log(3, 'Default args mapping: %s has: %s', entity,
                  entity.__defaults__)
      compiled.__defaults__ = entity.__defaults__
    else:
      logging.log(3, 'Default args mapping: %s has no __defaults__', entity)

    logging.log(3, 'Namespace of %s includes: %s', compiled,
                compiled_module.__dict__.keys())

    if hasattr(compiled, '__globals__'):
      # Remove self to avoid circular references. This will probably only work
      # so long as the function is not reentrant.
      del compiled.__globals__[name]

    # Need this so the source_mapping attribute is available for the context
    # manager to access for runtime errors.
    #
    # Note that compiler.ast_to_object attaches the source map 'ag_source_map__'
    # symbol to the compiled module.
    # TODO(mdan): Record this statically in the generated code.
    # TODO(mdan): Rename this attribute to 'autograph_info__'
    source_map_attribute_name = 'ag_source_map'
    if getattr(compiled, source_map_attribute_name, None) is not None:
      # TODO(znado): change input problem errors into TransformError
      raise ValueError('cannot convert %s because is has an attribute '
                       '"%s", which is reserved for AutoGraph.' %
                       (compiled, source_map_attribute_name))
    setattr(compiled, source_map_attribute_name,
            compiled_module.__dict__['ag_source_map__'])

    return compiled
  except (ValueError, AttributeError, KeyError, NameError, AssertionError) as e:
    errors.report_internal_error(entity, e)
示例#40
0
def converted_call(f, owner, options, args, kwargs):
    """Compiles a function call inline. For internal use only.

    Depending on what `f` is, the call is dispatched to a builtin overload,
    executed unconverted (whitelisted entities), or converted with `to_graph`
    so that the converted function is called instead.

    Args:
        f: The callable to call or, when `owner` is not None, the name (a
            string) of the attribute of `owner` that holds it.
        owner: Optional object whose attribute `f` is called. `super` objects
            are wrapped so that dynamic attribute lookups work.
        options: Conversion options. This function reads `force_conversion`,
            `internal_convert_user_code`, `recursive` and `optional_features`.
        args: Positional arguments for the call. Presumably a tuple: it is
            concatenated with `functools.partial.args` below.
        kwargs: Keyword arguments for the call, or None.

    Returns:
        The result of calling `f` (converted or not) with `args` and `kwargs`.

    Raises:
        ValueError: if `owner` is given but `f` is not a string.
    """
    logging.log(
        1, 'Converted call: %s; owner: %s\n    args: %s\n    kwargs: %s\n', f,
        owner, args, kwargs)

    if owner is not None:
        if not isinstance(f, str):
            raise ValueError(
                'When owner is specified, the function name must be specified as'
                ' a string: {}'.format(f))

        # Special case when the owner is a 'super' object. In that case lookups of
        # dynamic attributes won't work. See
        # inspect_utils.SuperWrapperForDynamicAttrs.
        if isinstance(owner, super):
            owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

        f = getattr(owner, f)

    if inspect_utils.isbuiltin(f):
        if kwargs:
            return py_builtins.overload_of(f)(*args, **kwargs)
        else:
            return py_builtins.overload_of(f)(*args)

    if _is_known_loaded_type(f, 'weakref', 'ref'):
        logging.log(2, 'Permanently whitelisted: %s: weakref', f)
        return _call_unconverted(f, args, kwargs)

    # TODO(b/122265385): Remove this bypass.
    if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper')
            or _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
        logging.warn(
            'Entity {} appears to be decorated by wrapt, which is not yet supported'
            ' by AutoGraph. The function will be called without transformation.'
            ' You may however apply AutoGraph before the decorator.'.format(f))
        logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
        return _call_unconverted(f, args, kwargs)

    # Constructors are permanently whitelisted.
    # TODO(mdan): Toggle as experimental feature instead.
    # TODO(b/124016764): Remove this limitation.
    if tf_inspect.isclass(f):
        logging.log(2, 'Permanently whitelisted: %s: constructor', f)
        return _call_unconverted(f, args, kwargs)

    # Other built-in modules are permanently whitelisted.
    # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
    # Note: TF linter disallows importing inspect.
    if any(f in m.__dict__.values()
           for m in (collections, pdb, copy, tf_inspect._inspect)):  # pylint:disable=protected-access
        logging.log(2, 'Permanently whitelisted: %s: part of builtin module',
                    f)
        return _call_unconverted(f, args, kwargs)

    if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
        return _call_unconverted(f, args, kwargs)

    # internal_convert_user_code is for example turned off when issuing a dynamic
    # call conversion from generated code while in nonrecursive mode. In that
    # case we evidently don't want to recurse, but we still have to convert
    # things like builtins.
    if not options.internal_convert_user_code:
        return _call_unconverted(f, args, kwargs)

    # Bound before the guarded block so the error handler below can always
    # report it. Unwrapping functools.partial can raise before the dispatch
    # assigns it (e.g. `f.args + args` raises TypeError when `args` is not a
    # tuple), and the handler would then fail with UnboundLocalError.
    target_entity = f

    # TODO(mdan): Move this entire block inside to_graph.
    try:  # Begin of transformation error guards

        # Unwrap functools.partial objects
        # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
        while isinstance(f, functools.partial):
            args = f.args + args
            new_kwargs = {}
            if f.keywords is not None:
                new_kwargs.update(f.keywords)
            if kwargs is not None:
                new_kwargs.update(kwargs)
            kwargs = new_kwargs
            f = f.func

        if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
            # Regular functions
            target_entity = f
            f_self = inspect_utils.getmethodself(f)

            # TODO(b/119246461): This may be more elegantly handled using __get__?
            if f_self is not None:
                effective_args = (f_self, ) + args
            else:
                effective_args = args

        elif tf_inspect.isclass(f):
            # Constructors
            # Note: Until we support class constructurs, and enable whole-class
            # conversion with an experimental flag, this branch is dead code.
            # TODO(mdan): Consider removing unless there is a compelling use case.
            target_entity = f
            effective_args = args

        elif hasattr(f, '__call__') and hasattr(f, '__class__'):
            # Callable objects
            target_entity = f.__call__
            effective_args = (f, ) + args

        else:
            target_entity = f
            raise NotImplementedError('unknown callable type "%s"' % type(f))

        converted_f = to_graph(
            target_entity,
            recursive=options.recursive,
            arg_values=None,
            arg_types=None,
            experimental_optional_features=options.optional_features)

        if logging.has_verbosity(2):
            logging.log(2, 'Defaults of %s : %s', converted_f,
                        converted_f.__defaults__)
            if kwargs is not None:
                callargs = tf_inspect.getcallargs(converted_f, *effective_args,
                                                  **kwargs)
            else:
                callargs = tf_inspect.getcallargs(converted_f, *effective_args)
            formatted_callargs = '\n'.join('    {}: {}'.format(k, v)
                                           for k, v in callargs.items())
            logging.log(2, 'Calling %s with\n%s\n', converted_f,
                        formatted_callargs)

    # TODO(mdan): Reduce this list.
    except (errors.AutoGraphError, AssertionError, AttributeError, IndexError,
            KeyError, NameError, NotImplementedError, SyntaxError, TypeError,
            ValueError, IOError) as e:

        logging.log(1,
                    'Error transforming entity %s',
                    target_entity,
                    exc_info=True)

        if is_autograph_strict_conversion_mode():
            raise

        logging.warn(
            'Entity %s could not be transformed and will be executed as-is.'
            ' Some features (e.g. tensor-dependent conditionals and loops) may not'
            ' work as expected.'
            ' Error details can be found in the logs when running with the env'
            ' variable AUTOGRAPH_VERBOSITY >= 1. Please report this to the'
            ' AutoGraph team. Cause: %s', target_entity, e)

        return _call_unconverted(f, args, kwargs)

    # The converted function runs outside the error guards, so exceptions
    # raised by user code propagate unchanged.
    if kwargs is not None:
        result = converted_f(*effective_args, **kwargs)
    else:
        result = converted_f(*effective_args)

    return result
示例#41
0
def is_whitelisted_for_graph(o):
    """Tells whether `o` may be used in graph mode without being converted.

    Typical examples are the members of the tensorflow package. The checks run
    in a fixed order and the first one that matches decides the result:
    unknown module (not whitelisted), uncompiled module prefix, already
    converted, callable object with a whitelisted `__call__`, method whose
    defining class is whitelisted, and namedtuple.

    Args:
        o: A Python entity.

    Returns:
        Boolean
    """
    # TODO(b/120224672): Fix this.
    # A functools.partial object has no __module__ attribute, so
    # tf_inspect.getmodule would return None for it; attribute it to functools.
    if isinstance(o, functools.partial):
        module = functools
    else:
        module = tf_inspect.getmodule(o)

    if not hasattr(module, '__name__'):
        # Builtins usually end up here; they are handled by dedicated code
        # that runs after this screening layer.
        logging.log(2, '%s is NOT whitelisted: unknown module name', o)
        return False

    for prefix, in config.DEFAULT_UNCOMPILED_MODULES:
        if not module.__name__.startswith(prefix):
            continue
        logging.log(2, '%s is whitelisted: name starts with "%s"', o, prefix)
        return True

    if hasattr(o, 'autograph_info__') or hasattr(o, '__ag_compiled'):
        logging.log(2, '%s is whitelisted: already converted', o)
        return True

    is_callable_instance = (not inspect_utils.isweakrefself(o)
                            and not tf_inspect.isclass(o)
                            and hasattr(o, '__call__')
                            and hasattr(o, '__class__'))
    if is_callable_instance:
        # A callable object takes the verdict of its __call__ method.
        call_verdict = is_whitelisted_for_graph(o.__call__)
        if call_verdict:
            logging.log(2, '%s is whitelisted: object __call__ whitelisted', o)
            return call_verdict

    if tf_inspect.ismethod(o):
        # A method is whitelisted when the class that defines it is, even if
        # it is reached through a user subclass. E.g. with a whitelisted
        # `tf.Foo` that has a method `bar`:
        #
        #   class Custom(tf.Foo):
        #     pass
        #
        #   baz = Custom()
        #
        # `baz.bar` is whitelisted, unless `Custom` overrides `bar`.
        method_owner = inspect_utils.getmethodclass(o)
        if method_owner is not None:
            defining_class = inspect_utils.getdefiningclass(o, method_owner)
            if is_whitelisted_for_graph(defining_class):
                logging.log(2, '%s is whitelisted: owner is whitelisted %s', o,
                            defining_class)
                return True

    if not inspect_utils.isnamedtuple(o):
        logging.log(2, '%s is NOT whitelisted', o)
        return False

    # namedtuple types expose no source code, so they cannot be converted;
    # being plain containers, they are assumed safe for graph mode.
    if tf_inspect.isclass(o) and len(o.__bases__) > 1:
        logging.warn_first_n(
            'Entity {} looks like a namedtuple subclass. If it has any custom'
            ' methods, they will not be converted by AutoGraph.'.format(o),
            1)
    logging.log(2, '%s is whitelisted: named tuple', o)
    return True
示例#42
0
def converted_call(f, owner, options, args, kwargs):
  """Compiles a function call inline. For internal use only.

  Depending on what `f` is, the call is dispatched to a builtin overload,
  executed unconverted (whitelisted entities, native bindings), or converted
  with `to_graph` so that the converted function is called instead.

  Args:
    f: The callable to call or, when `owner` is not None, the name (a string)
      of the attribute of `owner` that holds it.
    owner: Optional object whose attribute `f` is called. `super` objects are
      wrapped so that dynamic attribute lookups work.
    options: Conversion options. This function reads `force_conversion`,
      `internal_convert_user_code`, `recursive` and `optional_features`.
    args: Positional arguments for the call. Presumably a tuple: it is
      concatenated with `functools.partial.args` below.
    kwargs: Keyword arguments for the call, or None.

  Returns:
    The result of calling `f` (converted or not) with `args` and `kwargs`.

  Raises:
    ValueError: if `owner` is given but `f` is not a string.
  """
  owner_attr = None
  if owner is not None:
    if not isinstance(f, str):
      raise ValueError(
          'When owner is specified, the function name must be specified as'
          ' a string: {}'.format(f))
    owner_attr = f

    # Special case when the owner is a 'super' object. In that case lookups of
    # dynamic attributes won't work. See
    # inspect_utils.SuperWrapperForDynamicAttrs.
    if isinstance(owner, super):
      owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

    f = getattr(owner, f)

  if logging.has_verbosity(1):
    if owner is not None:
      composite_desc = '("{}" attr of {})'.format(owner_attr, owner)
    else:
      composite_desc = ''

    logging.log(1,
                'Converted call: %s %s\n    args: %s\n    kwargs: %s\n',
                f, composite_desc, args, kwargs)

  if inspect_utils.isbuiltin(f):
    if kwargs:
      return py_builtins.overload_of(f)(*args, **kwargs)
    else:
      return py_builtins.overload_of(f)(*args)

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
    logging.warn(
        'Entity {} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will be called without transformation.'
        ' You may however apply AutoGraph before the decorator.'.format(f))
    logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
    return _call_unconverted(f, args, kwargs)

  if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
    return _call_unconverted(f, args, kwargs)

  # Constructors are permanently whitelisted.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if tf_inspect.isclass(f):
    logging.log(2, 'Permanently whitelisted: %s: constructor', f)
    return _call_unconverted(f, args, kwargs)

  # Other built-in modules are permanently whitelisted.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(f in m.__dict__.values() for m in (collections, pdb, copy, inspect)):
    logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
    return _call_unconverted(f, args, kwargs)

  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
    return _call_unconverted(f, args, kwargs)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs)

  # Bound before the guarded block so the error handler below can always
  # report it. Unwrapping functools.partial can raise before the dispatch
  # assigns it (e.g. `f.args + args` raises TypeError when `args` is not a
  # tuple), and the handler would then fail with UnboundLocalError.
  target_entity = f
  # Set inside the guarded block; the native call itself happens after it.
  is_native_binding = False

  # TODO(mdan): Move this entire block inside to_graph.
  try:  # Begin of transformation error guards

    # Unwrap functools.partial objects
    # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
    while isinstance(f, functools.partial):
      args = f.args + args
      new_kwargs = {}
      if f.keywords is not None:
        new_kwargs.update(f.keywords)
      if kwargs is not None:
        new_kwargs.update(kwargs)
      kwargs = new_kwargs
      f = f.func

    if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
      # Regular functions
      target_entity = f
      f_self = inspect_utils.getmethodself(f)

      # TODO(b/119246461): This may be more elegantly handled using __get__?
      if f_self is not None:
        effective_args = (f_self,) + args
      else:
        effective_args = args

    elif tf_inspect.isclass(f):
      # Constructors
      # Note: Until we support class constructurs, and enable whole-class
      # conversion with an experimental flag, this branch is dead code.
      # TODO(mdan): Consider removing unless there is a compelling use case.
      target_entity = f
      effective_args = args

    elif hasattr(f, '__call__') and hasattr(f, '__class__'):
      # Callable objects
      target_entity = f.__call__
      effective_args = (f,) + args

    else:
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

    if (not tf_inspect.isclass(target_entity) and
        not hasattr(target_entity, '__code__')):
      logging.log(
          2, 'Permanently whitelisted: %s: native binding', target_entity)
      is_native_binding = True
    else:
      converted_f = to_graph(
          target_entity,
          recursive=options.recursive,
          arg_values=None,
          arg_types=None,
          experimental_optional_features=options.optional_features)

      if logging.has_verbosity(2):
        logging.log(2, 'Defaults of %s : %s', converted_f,
                    converted_f.__defaults__)
        if kwargs is not None:
          callargs = tf_inspect.getcallargs(
              converted_f, *effective_args, **kwargs)
        else:
          callargs = tf_inspect.getcallargs(converted_f, *effective_args)
        formatted_callargs = '\n'.join(
            '    {}: {}'.format(k, v) for k, v in callargs.items())
        logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)

  # TODO(mdan): Reduce this list.
  except (errors.AutoGraphError, AssertionError, AttributeError, IndexError,
          KeyError, NameError, NotImplementedError, SyntaxError, TypeError,
          ValueError, IOError) as e:

    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)

    if is_autograph_strict_conversion_mode():
      raise

    logging.warn(
        'Entity %s could not be transformed and will be executed as-is.'
        ' Some features (e.g. tensor-dependent conditionals and loops) may not'
        ' work as expected.'
        ' Error details can be found in the logs when running with the env'
        ' variable AUTOGRAPH_VERBOSITY >= 1. Please report this to the'
        ' AutoGraph team. Cause: %s', target_entity, e)

    return _call_unconverted(f, args, kwargs)

  if is_native_binding:
    # Called outside the transformation error guards: an exception raised by
    # the user's function must propagate as-is, not be mistaken for a
    # conversion failure that would run `f` a second time unconverted.
    return _call_unconverted(f, args, kwargs)

  if kwargs is not None:
    result = converted_f(*effective_args, **kwargs)
  else:
    result = converted_f(*effective_args)

  return result