Exemplo n.º 1
0
def convert_entity_to_ast(o, program_ctx):
  """Converts a Python class, function or method into an AST.

  Classes are handed to `convert_class_to_ast`; functions and methods are
  handed to `convert_func_to_ast`. Any other object is rejected.

  Args:
    o: The Python entity to convert.
    program_ctx: A ProgramContext object, forwarded to the selected converter.

  Returns:
    The (nodes, name, entity_info) tuple produced by the selected converter:
        * nodes: The converted AST nodes. Per the original documentation they
            represent an entity with an interface equivalent to `o` which,
            when executed, creates a TF graph.
        * name: The symbol name under which the new entity can be found.
        * entity_info: The converter's third result. NOTE(review): earlier
            docs described this as a namespace dict keyed by symbol name —
            confirm against the converters.

  Raises:
    NotImplementedError: if `o` is an object that is not a class, function or
      method (object conversion is not yet supported).
    ValueError: if the entity type is not supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    to_ast = convert_class_to_ast
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    # Functions and methods share a single conversion path.
    to_ast = convert_func_to_ast
  elif hasattr(o, '__class__'):
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    # NOTE(review): practically every Python object has `__class__`, so this
    # branch catches all remaining inputs and the ValueError below is
    # effectively unreachable.
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'python/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  nodes, name, entity_info = to_ast(o, program_ctx)

  # Verbose diagnostics: generated source at level 2, raw AST dump at level 4.
  if logging.has_verbosity(2):
    generated_source = compiler.ast_to_source(nodes)
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o, generated_source)
  if logging.has_verbosity(4):
    for node in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(node, color=False))

  return nodes, name, entity_info
Exemplo n.º 2
0
def convert_entity_to_ast(o, program_ctx):
  """Converts a Python class, function or method into an AST.

  Classes are handed to `convert_class_to_ast`; functions and methods are
  handed to `convert_func_to_ast`. Any other object is rejected.

  Args:
    o: The Python entity to convert.
    program_ctx: A ProgramContext object, forwarded to the selected converter.

  Returns:
    The (nodes, name, entity_info) tuple produced by the selected converter:
        * nodes: The converted AST nodes. Per the original documentation they
            represent an entity with an interface equivalent to `o` which,
            when executed, creates a TF graph.
        * name: The symbol name under which the new entity can be found.
        * entity_info: The converter's third result. NOTE(review): earlier
            docs described this as a namespace dict keyed by symbol name —
            confirm against the converters.

  Raises:
    NotImplementedError: if `o` is an object that is not a class, function or
      method (object conversion is not yet supported).
    ValueError: if the entity type is not supported.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    to_ast = convert_class_to_ast
  elif tf_inspect.isfunction(o) or tf_inspect.ismethod(o):
    # Functions and methods share a single conversion path.
    to_ast = convert_func_to_ast
  elif hasattr(o, '__class__'):
    # Note: this should only be raised when attempting to convert the object
    # directly. converted_call should still support it.
    # NOTE(review): practically every Python object has `__class__`, so this
    # branch catches all remaining inputs and the ValueError below is
    # effectively unreachable.
    raise NotImplementedError(
        'cannot convert entity "{}": object conversion is not yet'
        ' supported.'.format(o))
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  nodes, name, entity_info = to_ast(o, program_ctx)

  # Verbose diagnostics: generated source at level 2, raw AST dump at level 4.
  if logging.has_verbosity(2):
    generated_source = compiler.ast_to_source(nodes)
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o, generated_source)
  if logging.has_verbosity(4):
    for node in nodes:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(node, color=False))

  return nodes, name, entity_info
Exemplo n.º 3
0
def converted_call(f,
                   args,
                   kwargs,
                   caller_fn_scope=None,
                   options=None):
  """Compiles a function call inline.

  For internal use only.

  Note: The argument list is optimized for readability of generated code, which
  may look like this:

    ag__.converted_call(f, (arg1, arg2), None, fscope)
    ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
    ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

  The checks below run in order, and the first one that matches decides how
  `f` is executed:
    * entities in the whitelist cache, calls made while AutoGraph is disabled
      in the current context, and AutoGraph artifacts are called unconverted;
    * `functools.partial` objects are unwrapped, and the whole procedure is
      redone on the underlying function with the merged arguments;
    * builtins are dispatched through `py_builtins` (`eval` and `super` are
      resolved against `caller_fn_scope`);
    * wrapt-decorated functions, `functools.lru_cache` wrappers, constructors,
      members of a few stdlib modules, TensorFlow plugins, whitelisted
      entities (unless the conversion was user-requested), and all calls made
      while `options.internal_convert_user_code` is off are called
      unconverted;
    * anything else is converted with `conversion.convert` and the converted
      function is called. If the transformation fails, a warning is logged
      and `f` is called unconverted instead, unless strict conversion mode is
      on, in which case the error is re-raised.

  Args:
    f: The function to convert.
    args: Tuple, the original positional arguments of f
    kwargs: Optional[Dict], the original keyword arguments of f
    caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
      scope of the converted function in which this call was originally made.
    options: Optional[converter.ConversionOptions], conversion options. If not
      specified, the value of caller_fn_scope.callopts is used. Either options
      or caller_fn_scope must be present.

  Returns:
    Any, the result of executing a possibly-converted `f` with the given
      arguments.

  Raises:
    ValueError: if neither `options` nor `caller_fn_scope` is given.
    Any exception raised by the converted function is re-raised after
    `_attach_metadata` has been applied to it.
  """
  logging.log(1, 'Converted call: %s\n    args: %s\n    kwargs: %s\n', f, args,
              kwargs)

  # Without explicit options, inherit the call options of the caller's scope.
  if options is None:
    if caller_fn_scope is None:
      raise ValueError('either caller_fn_scope or options must have a value')
    options = caller_fn_scope.callopts

  if conversion.is_in_whitelist_cache(f, options):
    logging.log(2, 'Whitelisted %s: from cache', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
    logging.log(2, 'Whitelisted: %s: AutoGraph is disabled in context', f)
    return _call_unconverted(f, args, kwargs, options, False)

  if is_autograph_artifact(f):
    logging.log(2, 'Permanently whitelisted: %s: AutoGraph artifact', f)
    return _call_unconverted(f, args, kwargs, options)

  # If this is a partial, unwrap it and redo all the checks.
  if isinstance(f, functools.partial):
    new_kwargs = {}
    if f.keywords is not None:
      # Use copy to avoid mutating the underlying keywords.
      new_kwargs = f.keywords.copy()
    if kwargs is not None:
      # Call-site keywords take precedence over the partial's bound keywords.
      new_kwargs.update(kwargs)
    new_args = f.args + args
    logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f, new_args,
                new_kwargs)
    return converted_call(
        f.func,
        new_args,
        new_kwargs,
        caller_fn_scope=caller_fn_scope,
        options=options)

  # Builtins are routed to their AutoGraph overloads. `eval` and `super` are
  # resolved against the caller's scope so they see the original context.
  if inspect_utils.isbuiltin(f):
    if f is eval:
      return py_builtins.eval_in_original_context(f, args, caller_fn_scope)
    if f is super:
      return py_builtins.super_in_original_context(f, args, caller_fn_scope)
    if kwargs:
      return py_builtins.overload_of(f)(*args, **kwargs)
    else:
      return py_builtins.overload_of(f)(*args)

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
    logging.warn(
        '{} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will run as-is.'
        ' You may still apply AutoGraph before the wrapt decorator.'.format(f))
    logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
    return _call_unconverted(f, args, kwargs, options)

  if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
    return _call_unconverted(f, args, kwargs, options)

  # Constructors are permanently whitelisted.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if inspect_utils.isconstructor(f):
    logging.log(2, 'Permanently whitelisted: %s: constructor', f)
    return _call_unconverted(f, args, kwargs, options)

  # Other built-in modules are permanently whitelisted.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(
      f in m.__dict__.values() for m in (collections, pdb, copy, inspect, re)):
    logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
    return _call_unconverted(f, args, kwargs, options)

  # Custom ops and kernels are also permanently whitelisted.
  # See tensorflow.framework.load_library.
  # NOTE(review): for functions, `__module__` is normally the module's *name*
  # (a str), so this hasattr check may never succeed — confirm the intended
  # target (perhaps sys.modules[f.__module__]).
  if (hasattr(f, '__module__') and
      hasattr(f.__module__, '_IS_TENSORFLOW_PLUGIN')):
    logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f)
    return _call_unconverted(f, args, kwargs, options)

  # The whitelist is bypassed when the user explicitly asked for conversion.
  if not options.user_requested and conversion.is_whitelisted(f):
    return _call_unconverted(f, args, kwargs, options)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs, options)

  # TODO(mdan): Move this entire block inside to_graph.
  try:  # Begin of transformation error guards

    # Pick the entity to convert (`target_entity`) and the positional args it
    # will be called with (`effective_args`).
    if inspect.ismethod(f) or inspect.isfunction(f):
      target_entity = f
      effective_args = args

      # The converted function is called with `self` passed explicitly as the
      # first positional argument.
      f_self = getattr(f, '__self__', None)
      if f_self is not None:
        # Methods bound to a TfMethodTarget wrapper use the wrapped target as
        # `self`.
        if isinstance(f_self, function.TfMethodTarget):
          f_self = f_self.target
        effective_args = (f_self,) + effective_args

    elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
      # Callable objects. Dunder methods have special lookup rules, see:
      # https://docs.python.org/3/reference/datamodel.html#specialnames
      target_entity = f.__class__.__call__
      effective_args = (f,) + args

    else:
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

    # Entities without Python bytecode (native bindings) or whose code was
    # created from a string (exec) cannot be transformed; run them as-is.
    if not inspect.isclass(target_entity):
      if not hasattr(target_entity, '__code__'):
        logging.log(2, 'Permanently whitelisted: %s: native binding',
                    target_entity)
        return _call_unconverted(f, args, kwargs, options)
      elif (hasattr(target_entity.__code__, 'co_filename') and
            target_entity.__code__.co_filename == '<string>'):
        # TODO(mdan): __globals__['txt'] might work in Py3.
        logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)',
                    target_entity)
        return _call_unconverted(f, args, kwargs, options)

    program_ctx = converter.ProgramContext(
        options=options, autograph_module=tf_inspect.getmodule(converted_call))
    converted_f = conversion.convert(target_entity, program_ctx)

    # Debug logging only: shows defaults and how the arguments bind.
    if logging.has_verbosity(2):
      logging.log(2, 'Defaults of %s : %s', converted_f,
                  converted_f.__defaults__)
      if not six.PY2:
        logging.log(2, 'KW defaults of %s : %s',
                    converted_f, converted_f.__kwdefaults__)

      if kwargs is not None:
        callargs = tf_inspect.getcallargs(converted_f, *effective_args,
                                          **kwargs)
      else:
        callargs = tf_inspect.getcallargs(converted_f, *effective_args)

      formatted_callargs = '\n'.join(
          '    {}: {}'.format(k, v) for k, v in callargs.items())
      logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)

  except Exception as e:  # pylint:disable=broad-except
    # Transformation failed: unless strict mode is on, warn and fall back to
    # calling the original, unconverted function.
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    if is_autograph_strict_conversion_mode():
      raise

    warning_template = (
        'AutoGraph could not transform %s and will run it as-is.\n'
        '%s'
        'Cause: %s\n'
        'To silence this warning, decorate the function with'
        ' @tf.autograph.experimental.do_not_convert')
    if isinstance(e, errors.UnsupportedLanguageElementError):
      # Repeating the check made upon function entry because the state might
      # have updated in the meantime.
      if not conversion.is_in_whitelist_cache(f, options):
        logging.warn(warning_template, target_entity, '', e)
    else:
      file_bug_message = (
          'Please report this to the TensorFlow team. When filing the bug, set'
          ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
          ' attach the full output.\n')
      logging.warn(warning_template, target_entity, file_bug_message, e)

    return _call_unconverted(f, args, kwargs, options)

  # Run the converted function; errors it raises get metadata attached via
  # _attach_metadata before being re-raised.
  with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
    try:
      if kwargs is not None:
        result = converted_f(*effective_args, **kwargs)
      else:
        result = converted_f(*effective_args)
    except Exception as e:
      _attach_metadata(e, converted_f)
      raise

  return result
Exemplo n.º 4
0
def converted_call(f, owner, options, args, kwargs):
    """Compiles a function call inline. For internal use only.

    `f` is called unconverted if any of the following holds:
      * it is a wrapt-decorated function, an lru_cache wrapper or a class;
      * it belongs to one of a few stdlib modules;
      * it is whitelisted and `options.force_conversion` is off;
      * `options.internal_convert_user_code` is off;
      * it is a native binding (has no `__code__`).
    Builtins are dispatched to their `py_builtins` overloads. Anything else is
    transformed with `to_graph`, and the result is called. If the
    transformation fails with one of the caught error types, a warning is
    logged and `f` is called unconverted instead, unless strict conversion
    mode is on, in which case the error is re-raised.

    Args:
        f: The callable to invoke or, when `owner` is not None, the name (a
            string) of the attribute of `owner` to invoke.
        owner: Optional object owning `f`; if given, `f` is looked up with
            `getattr(owner, f)`.
        options: Conversion options; this function reads `force_conversion`,
            `internal_convert_user_code`, `recursive` and `optional_features`.
        args: Tuple, the positional arguments of the call.
        kwargs: Optional dict, the keyword arguments of the call.

    Returns:
        The result of calling `f` (converted, or as-is) with the arguments.

    Raises:
        ValueError: if `owner` is given but `f` is not a string.
    """
    if owner is not None:
        if not isinstance(f, str):
            raise ValueError(
                'When owner is specified, the function name must be specified as'
                ' a string: {}'.format(f))
        owner_attr = f

        # Special case when the owner is a 'super' object. In that case lookups of
        # dynamic attributes won't work. See
        # inspect_utils.SuperWrapperForDynamicAttrs.
        if isinstance(owner, super):
            owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

        f = getattr(owner, f)

    if logging.has_verbosity(1):
        if owner is not None:
            composite_desc = '("{}" attr of {})'.format(owner_attr, owner)
        else:
            composite_desc = ''

        logging.log(1, 'Converted call: %s %s\n    args: %s\n    kwargs: %s\n',
                    f, composite_desc, args, kwargs)

    # Builtins are routed to their AutoGraph overloads.
    if inspect_utils.isbuiltin(f):
        if kwargs:
            return py_builtins.overload_of(f)(*args, **kwargs)
        else:
            return py_builtins.overload_of(f)(*args)

    # TODO(b/122265385): Remove this bypass.
    if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper')
            or _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
        logging.warn(
            'Entity {} appears to be decorated by wrapt, which is not yet supported'
            ' by AutoGraph. The function will be called without transformation.'
            ' You may however apply AutoGraph before the decorator.'.format(f))
        logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
        return _call_unconverted(f, args, kwargs)

    if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
        logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
        return _call_unconverted(f, args, kwargs)

    # Constructors are permanently whitelisted.
    # TODO(mdan): Toggle as experimental feature instead.
    # TODO(b/124016764): Remove this limitation.
    if tf_inspect.isclass(f):
        logging.log(2, 'Permanently whitelisted: %s: constructor', f)
        return _call_unconverted(f, args, kwargs)

    # Other built-in modules are permanently whitelisted.
    # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
    if any(f in m.__dict__.values()
           for m in (collections, pdb, copy, inspect)):
        logging.log(2, 'Permanently whitelisted: %s: part of builtin module',
                    f)
        return _call_unconverted(f, args, kwargs)

    if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
        return _call_unconverted(f, args, kwargs)

    # internal_convert_user_code is for example turned off when issuing a dynamic
    # call conversion from generated code while in nonrecursive mode. In that
    # case we evidently don't want to recurse, but we still have to convert
    # things like builtins.
    if not options.internal_convert_user_code:
        return _call_unconverted(f, args, kwargs)

    # Bind `target_entity` before entering the guarded region. The error
    # handler below logs it, but the functools.partial unwrapping runs before
    # any branch assigns a more specific target, and it can itself raise (e.g.
    # a TypeError when `args` is not a tuple). Without this binding, such an
    # error would surface as an UnboundLocalError inside the handler.
    target_entity = f

    # TODO(mdan): Move this entire block inside to_graph.
    try:  # Begin of transformation error guards

        # Unwrap functools.partial objects
        # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
        while isinstance(f, functools.partial):
            args = f.args + args
            new_kwargs = {}
            if f.keywords is not None:
                new_kwargs.update(f.keywords)
            if kwargs is not None:
                # Call-site keywords take precedence over bound keywords.
                new_kwargs.update(kwargs)
            kwargs = new_kwargs
            f = f.func

        if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
            # Regular functions
            target_entity = f
            f_self = inspect_utils.getmethodself(f)

            # The converted function is called with `self` passed explicitly.
            # TODO(b/119246461): This may be more elegantly handled using __get__?
            if f_self is not None:
                effective_args = (f_self, ) + args
            else:
                effective_args = args

        elif tf_inspect.isclass(f):
            # Constructors
            # Note: Until we support class constructurs, and enable whole-class
            # conversion with an experimental flag, this branch is dead code.
            # TODO(mdan): Consider removing unless there is a compelling use case.
            target_entity = f
            effective_args = args

        elif hasattr(f, '__call__') and hasattr(f, '__class__'):
            # Callable objects
            target_entity = f.__call__
            effective_args = (f, ) + args

        else:
            target_entity = f
            raise NotImplementedError('unknown callable type "%s"' % type(f))

        # Entities without Python bytecode cannot be transformed.
        if (not tf_inspect.isclass(target_entity)
                and not hasattr(target_entity, '__code__')):
            logging.log(2, 'Permanently whitelisted: %s: native binding',
                        target_entity)
            return _call_unconverted(f, args, kwargs)

        converted_f = to_graph(
            target_entity,
            recursive=options.recursive,
            arg_values=None,
            arg_types=None,
            experimental_optional_features=options.optional_features)

        # Debug logging only: shows defaults and how the arguments bind.
        if logging.has_verbosity(2):
            logging.log(2, 'Defaults of %s : %s', converted_f,
                        converted_f.__defaults__)
            if kwargs is not None:
                callargs = tf_inspect.getcallargs(converted_f, *effective_args,
                                                  **kwargs)
            else:
                callargs = tf_inspect.getcallargs(converted_f, *effective_args)
            formatted_callargs = '\n'.join('    {}: {}'.format(k, v)
                                           for k, v in callargs.items())
            logging.log(2, 'Calling %s with\n%s\n', converted_f,
                        formatted_callargs)

    # TODO(mdan): Reduce this list.
    except (errors.AutoGraphError, AssertionError, AttributeError, IndexError,
            KeyError, NameError, NotImplementedError, SyntaxError, TypeError,
            ValueError, IOError) as e:

        logging.log(1,
                    'Error transforming entity %s',
                    target_entity,
                    exc_info=True)

        if is_autograph_strict_conversion_mode():
            raise

        logging.warn(
            'Entity %s could not be transformed and will be executed as-is.'
            ' Some features (e.g. tensor-dependent conditionals and loops) may not'
            ' work as expected.'
            ' Error details can be found in the logs when running with the env'
            ' variable AUTOGRAPH_VERBOSITY >= 1. Please report this to the'
            ' AutoGraph team. Cause: %s', target_entity, e)

        return _call_unconverted(f, args, kwargs)

    if kwargs is not None:
        result = converted_f(*effective_args, **kwargs)
    else:
        result = converted_f(*effective_args)

    return result
Exemplo n.º 5
0
def create_source_map(nodes, code, filepath):
  """Maps lines of compiled code back to the origin annotations of their AST.

  `code` is reparsed and its nodes are resolved against `filepath`. The
  reparsed tree is then walked in lockstep with `nodes`. Whenever both sides
  carry an ORIGIN annotation, the line in the compiled code is mapped to the
  original node's origin info.

  Note: this function assumes `nodes`, `code` and `filepath` correspond to the
  same code.

  Args:
    nodes: Iterable[ast.AST, ...], one or more AST nodes.
    code: Text, the source code in which nodes are found.
    filepath: Text

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.

  Raises:
    ValueError: propagated from `ast_util.parallel_walk` when the two trees
      cannot be walked together. At verbosity >= 3 a diff of the two trees is
      logged first.
  """
  reparsed_nodes = parser.parse_str(code, preamble_len=0, single_node=False)
  for reparsed_root in reparsed_nodes:
    resolve(reparsed_root, code, filepath, reparsed_root.lineno,
            reparsed_root.col_offset)

  source_map = {}

  try:
    for src_node, out_node in ast_util.parallel_walk(nodes, reparsed_nodes):
      # Note: generated code might not be mapped back to its origin.
      # TODO(mdan): Generated code should always be mapped to something.
      src_origin = anno.getanno(src_node, anno.Basic.ORIGIN, default=None)
      out_origin = anno.getanno(out_node, anno.Basic.ORIGIN, default=None)
      if src_origin is None or out_origin is None:
        continue

      # Keys are by line only; the column offset is deliberately excluded.
      key = LineLocation(out_origin.loc.filename, out_origin.loc.lineno)

      previous = source_map.get(key)
      if previous is not None:
        # Overlaps may exist because of child nodes, but almost never to
        # different line locations. The exception is decorated functions,
        # where both lines are mapped to the same line in the AST.
        # Line overlaps keep the bottom node; column overlaps keep the
        # leftmost node.
        same_line = previous.loc.line_loc == src_origin.loc.line_loc
        if same_line and previous.loc.lineno >= src_origin.loc.lineno:
          continue
        if previous.loc.col_offset <= src_origin.loc.col_offset:
          continue

      source_map[key] = src_origin

  except ValueError:
    # The trees diverged; optionally log a diff to help debug, then re-raise.
    if logging.has_verbosity(3):
      for src_node, out_node in zip(nodes, reparsed_nodes):
        src_dump = pretty_printer.fmt(src_node, color=False, noanno=True)
        out_dump = pretty_printer.fmt(out_node, color=False, noanno=True)
        diff_lines = difflib.context_diff(
            src_dump.split('\n'),
            out_dump.split('\n'),
            fromfile='Original nodes',
            tofile='Reparsed nodes',
            n=7)
        logging.log(3, 'AST seems to lack integrity. Diff:\n%s',
                    '\n'.join(diff_lines))
    raise

  return source_map
Exemplo n.º 6
0
def converted_call(f, owner, options, args, kwargs):
    """Compiles a function call inline. For internal use only.

    `f` is called unconverted if any of the following holds:
      * it already carries AutoGraph markers;
      * it is a wrapt-decorated function, an lru_cache wrapper or a class;
      * it belongs to one of a few stdlib modules;
      * it is flagged as a TensorFlow plugin;
      * it is whitelisted and `options.force_conversion` is off;
      * `options.internal_convert_user_code` is off;
      * it is a native binding or exec-generated code.
    Builtins are dispatched to their `py_builtins` overloads. Anything else is
    transformed with `to_graph`, and the result is called. If the
    transformation fails, a warning is logged and `f` is called unconverted
    instead, unless strict conversion mode is on, in which case the error is
    re-raised.

    Args:
        f: The callable to invoke or, when `owner` is not None, the name (a
            string) of the attribute of `owner` to invoke.
        owner: Optional object owning `f`; if given, `f` is looked up with
            `getattr(owner, f)`.
        options: Conversion options; this function reads `force_conversion`,
            `internal_convert_user_code`, `recursive` and `optional_features`.
        args: Tuple, the positional arguments of the call.
        kwargs: Optional dict, the keyword arguments of the call.

    Returns:
        The result of calling `f` (converted, or as-is) with the arguments.

    Raises:
        ValueError: if `owner` is given but `f` is not a string.
        Any exception raised by the converted function is re-raised after
        `_attach_metadata` has been applied to it.
    """
    if owner is not None:
        if not isinstance(f, str):
            raise ValueError(
                'When owner is specified, the function name must be specified as'
                ' a string: {}'.format(f))
        owner_attr = f

        # Special case when the owner is a 'super' object. In that case lookups of
        # dynamic attributes won't work. See
        # inspect_utils.SuperWrapperForDynamicAttrs.
        if isinstance(owner, super):
            owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

        f = getattr(owner, f)

    if logging.has_verbosity(1):
        if owner is not None:
            composite_desc = '("{}" attr of {})'.format(owner_attr, owner)
        else:
            composite_desc = ''

        logging.log(1, 'Converted call: %s %s\n    args: %s\n    kwargs: %s\n',
                    f, composite_desc, args, kwargs)

    # Builtins are routed to their AutoGraph overloads.
    if inspect_utils.isbuiltin(f):
        if kwargs:
            return py_builtins.overload_of(f)(*args, **kwargs)
        else:
            return py_builtins.overload_of(f)(*args)

    # TODO(mdan): Clean up the naming inconsistency.
    if hasattr(f, 'autograph_info__') or hasattr(f, '__ag_compiled'):
        logging.log(2, 'Permanently whitelisted: %s: already converted', f)
        return _call_unconverted(f, args, kwargs)

    # TODO(b/122265385): Remove this bypass.
    if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper')
            or _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
        logging.warn(
            'Entity {} appears to be decorated by wrapt, which is not yet supported'
            ' by AutoGraph. The function will be called without transformation.'
            ' You may however apply AutoGraph before the decorator.'.format(f))
        logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
        return _call_unconverted(f, args, kwargs)

    if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
        logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
        return _call_unconverted(f, args, kwargs)

    # Constructors are permanently whitelisted.
    # TODO(mdan): Toggle as experimental feature instead.
    # TODO(b/124016764): Remove this limitation.
    if tf_inspect.isclass(f):
        logging.log(2, 'Permanently whitelisted: %s: constructor', f)
        return _call_unconverted(f, args, kwargs)

    # Other built-in modules are permanently whitelisted.
    # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
    if any(f in m.__dict__.values()
           for m in (collections, pdb, copy, inspect, re)):
        logging.log(2, 'Permanently whitelisted: %s: part of builtin module',
                    f)
        return _call_unconverted(f, args, kwargs)

    # Custom ops and kernels are also permanently whitelisted.
    # See tensorflow.framework.load_library.
    # NOTE(review): for functions, `__module__` is normally the module's *name*
    # (a str), so this hasattr check may never succeed — confirm the intended
    # target.
    if (hasattr(f, '__module__')
            and hasattr(f.__module__, '_IS_TENSORFLOW_PLUGIN')):
        logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f)
        return _call_unconverted(f, args, kwargs)

    if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
        return _call_unconverted(f, args, kwargs)

    # internal_convert_user_code is for example turned off when issuing a dynamic
    # call conversion from generated code while in nonrecursive mode. In that
    # case we evidently don't want to recurse, but we still have to convert
    # things like builtins.
    if not options.internal_convert_user_code:
        return _call_unconverted(f, args, kwargs)

    # Bind `target_entity` before entering the guarded region. The error
    # handler below logs it, but the functools.partial unwrapping runs before
    # any branch assigns a more specific target, and it can itself raise (e.g.
    # a TypeError when `args` is not a tuple). Without this binding, such an
    # error would surface as an UnboundLocalError inside the handler.
    target_entity = f

    # TODO(mdan): Move this entire block inside to_graph.
    try:  # Begin of transformation error guards

        # Unwrap functools.partial objects
        # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
        while isinstance(f, functools.partial):
            args = f.args + args
            new_kwargs = {}
            if f.keywords is not None:
                new_kwargs.update(f.keywords)
            if kwargs is not None:
                # Call-site keywords take precedence over bound keywords.
                new_kwargs.update(kwargs)
            kwargs = new_kwargs
            f = f.func

        if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
            # Regular functions
            target_entity = f
            f_self = inspect_utils.getmethodself(f)

            # The converted function is called with `self` passed explicitly.
            # TODO(b/119246461): This may be more elegantly handled using __get__?
            if f_self is not None:
                effective_args = (f_self, ) + args
            else:
                effective_args = args

        elif hasattr(f, '__call__') and hasattr(f, '__class__'):
            # Callable objects
            target_entity = f.__call__
            effective_args = (f, ) + args

        elif tf_inspect.isclass(f):
            # Constructors
            # Note: Until we support class constructurs, and enable whole-class
            # conversion with an experimental flag, this branch is dead code.
            # TODO(mdan): Consider removing unless there is a compelling use case.
            target_entity = f
            effective_args = args

        else:
            target_entity = f
            raise NotImplementedError('unknown callable type "%s"' % type(f))

        # Entities without Python bytecode (native bindings) or whose code was
        # created from a string (exec) cannot be transformed; run them as-is.
        if not tf_inspect.isclass(target_entity):
            if not hasattr(target_entity, '__code__'):
                logging.log(2, 'Permanently whitelisted: %s: native binding',
                            target_entity)
                return _call_unconverted(f, args, kwargs)
            elif (hasattr(target_entity.__code__, 'co_filename')
                  and target_entity.__code__.co_filename == '<string>'):
                # TODO(mdan): __globals__['txt'] might work in Py3.
                logging.log(
                    2, 'Permanently whitelisted: %s: dynamic code (exec?)',
                    target_entity)
                return _call_unconverted(f, args, kwargs)

        converted_f = to_graph(
            target_entity,
            recursive=options.recursive,
            experimental_optional_features=options.optional_features)

        # Debug logging only: shows defaults and how the arguments bind.
        if logging.has_verbosity(2):
            logging.log(2, 'Defaults of %s : %s', converted_f,
                        converted_f.__defaults__)
            if six.PY3:
                logging.log(2, 'KW defaults of %s : %s', converted_f,
                            converted_f.__kwdefaults__)

            if kwargs is not None:
                callargs = tf_inspect.getcallargs(converted_f, *effective_args,
                                                  **kwargs)
            else:
                callargs = tf_inspect.getcallargs(converted_f, *effective_args)

            formatted_callargs = '\n'.join('    {}: {}'.format(k, v)
                                           for k, v in callargs.items())
            logging.log(2, 'Calling %s with\n%s\n', converted_f,
                        formatted_callargs)

    except Exception as e:  # pylint:disable=broad-except
        logging.log(1,
                    'Error transforming entity %s',
                    target_entity,
                    exc_info=True)
        if is_autograph_strict_conversion_mode():
            raise
        logging.warn(
            'Entity %s could not be transformed and will be executed as-is.'
            ' Please report this to the AutoGraph team. When filing the bug, set'
            ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
            ' attach the full output. Cause: %s', target_entity, e)
        return _call_unconverted(f, args, kwargs)

    # Run the converted function; errors it raises get metadata attached via
    # _attach_metadata before being re-raised.
    with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
        try:
            if kwargs is not None:
                result = converted_f(*effective_args, **kwargs)
            else:
                result = converted_f(*effective_args)
        except Exception as e:
            _attach_metadata(e, converted_f, True)
            raise

    return result
Exemplo n.º 7
0
def converted_call(f, owner, options, *args, **kwargs):
  """Compiles a function call inline. For internal use only.

  Builtins are dispatched to their AutoGraph overloads. The following are
  called as-is:
    * wrapt-decorated functions,
    * members of a few stdlib modules,
    * whitelisted entities,
    * all other callables, when `options.internal_convert_user_code` is off.
  Anything else is converted with `to_graph`, and the converted function is
  called in its place. If conversion raises one of the errors listed in the
  `except` clause below, a warning is logged and `f` is called unconverted.

  Args:
    f: The callable to invoke or, when `owner` is not None, the name (a
      string) of the attribute of `owner` to invoke.
    owner: Optional object on which `f` is looked up. `super` objects are
      wrapped in `inspect_utils.SuperWrapperForDynamicAttrs` first.
    options: Conversion options. This function reads `force_conversion`,
      `internal_convert_user_code`, `recursive`, `optional_features`,
      `strip_decorators` and `verbose`.
    *args: Positional arguments of the call.
    **kwargs: Keyword arguments of the call.

  Returns:
    The result of calling `f`, converted or not, with the given arguments.

  Raises:
    ValueError: if `owner` is specified but `f` is not a string.
  """
  logging.log(1,
              'Converted call: %s; owner: %s\n    args: %s\n    kwargs: %s\n',
              f, owner, args, kwargs)

  if owner is not None:
    if not isinstance(f, str):
      raise ValueError(
          'When owner is specified, the function name must be specified as'
          ' a string: {}'.format(f))

    # Special case when the owner is a 'super' object. In that case lookups of
    # dynamic attributes won't work. See
    # inspect_utils.SuperWrapperForDynamicAttrs.
    if isinstance(owner, super):
      owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

    f = getattr(owner, f)

  if inspect_utils.isbuiltin(f):
    return py_builtins.overload_of(f)(*args, **kwargs)

  # TODO(b/122265385): Remove this bypass.
  if ('wrapt' in sys.modules and
      hasattr(sys.modules['wrapt'], 'FunctionWrapper') and
      isinstance(f, sys.modules['wrapt'].FunctionWrapper)):
    # Pass `f` as a lazy %-style argument. A pre-formatted message with an
    # extra positional argument (and no placeholder for it) makes the
    # logger's `msg % args` fail, so the warning would never be shown.
    logging.warn(
        'Entity %s appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will be called without transformation.'
        ' You may however apply AutoGraph before the decorator.', f)
    logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
    return f(*args, **kwargs)

  # Other built-in modules are permanently whitelisted.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if (f in collections.__dict__.values() or f in pdb.__dict__.values() or
      f in copy.__dict__.values()):
    logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
    return f(*args, **kwargs)

  # TODO(mdan): This needs cleanup.
  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):

    # TODO(mdan): This may be inconsistent in certain situations.
    # If the function had already been annotated with @tf.function, it
    # may be bound to the incorrect object. It's unclear if those situations
    # are possible, but if they happen, we need to check if f is bound
    # to a shim like WeakrefSelf and unpack it.

    # Args typically include `self`, as required by the conversion process.
    # When conversion is skipped, `self` is not necessary, because the
    # original bound method is being executed. This code removes it.
    if tf_inspect.ismethod(f) and args:
      f_self = inspect_utils.getmethodself(f)
      if args[0] is f_self:
        args = args[1:]

    return f(*args, **kwargs)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return f(*args, **kwargs)

  # Bound before anything in the guarded block can raise, so the error
  # handler below can always name the entity. Without this, the
  # NotImplementedError branch would surface as an UnboundLocalError instead.
  target_entity = f

  # TODO(mdan): Move this entire block inside to_graph.
  try:  # Begin of transformation error guards

    # Unwrap functools.partial objects
    # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
    while isinstance(f, functools.partial):
      args = f.args + args
      new_kwargs = {}
      if f.keywords is not None:
        new_kwargs.update(f.keywords)
      new_kwargs.update(kwargs)
      kwargs = new_kwargs
      f = f.func

    if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
      # Regular functions
      target_entity = f
      arg_map_target = f
      f_self = inspect_utils.getmethodself(f)

      # TODO(b/119246461): This may be more elegantly handled using __get__?
      if f_self is not None:
        # If this is a method call, it may or may not include self.
        #
        # Example when self is included:
        #   converted_call(to_graph(foo.bar), foo)
        #
        # Example when self is not included:
        #   super(...).foo(args)
        #
        if owner is not None and (not args or args[0] is not owner):
          effective_args = (owner,) + args
        else:
          # When the owner is not specified, use the result of
          # inspect_utils.getmethodclass.
          # TODO(b/119246461): Make sure an owner is always specified.
          if not args or args[0] is not f_self:
            effective_args = (f_self,) + args
          else:
            effective_args = (f_self,) + args[1:]
        partial_types = (f_self,)
      else:
        effective_args = args
        partial_types = ()

    elif tf_inspect.isclass(f):
      # Constructors
      target_entity = f
      arg_map_target = f.__init__
      effective_args = args
      partial_types = ()

    elif hasattr(f, '__call__') and hasattr(f, '__class__'):
      # Callable objects
      target_entity = f.__call__
      arg_map_target = f.__call__
      effective_args = (f,) + args
      partial_types = (f.__class__,)

    else:
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

    arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
    arg_types = {}
    for name, arg in arg_values.items():
      arg_class = arg.__class__
      arg_types[name] = (arg_class.__name__, arg_class)

    # When called from within a decorator, this is the only indication that
    # the function is a method - it appears that the decorator is applied
    # before the method is bound.
    if not partial_types:
      if 'self' in arg_values:
        if tf_inspect.isclass(arg_values['self'].__class__):
          partial_types = (arg_values['self'].__class__,)
      elif 'cls' in arg_values:
        if tf_inspect.isclass(arg_values['cls']):
          partial_types = (arg_values['cls'],)

    logging.log(3, 'Partial types in conversion of %s: %s', target_entity,
                partial_types)

    converted_f = to_graph(
        target_entity,
        recursive=options.recursive,
        arg_values=arg_values,
        arg_types=arg_types,
        experimental_optional_features=options.optional_features,
        experimental_strip_decorators=options.strip_decorators,
        experimental_verbose=options.verbose,
        experimental_partial_types=partial_types)

    if logging.has_verbosity(2):
      logging.log(2, 'Defaults of %s : %s', converted_f,
                  converted_f.__defaults__)
      callargs = tf_inspect.getcallargs(converted_f, *effective_args, **kwargs)
      formatted_callargs = '\n'.join(
          '    {}: {}'.format(k, v) for k, v in callargs.items())
      logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)

  # TODO(mdan): Reduce this list.
  except (errors.AutoGraphError, AssertionError, AttributeError, IndexError,
          KeyError, NameError, NotImplementedError, SyntaxError, TypeError,
          ValueError, IOError) as e:
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    logging.warn(
        'Entity %s could not be transformed and will be staged without change.'
        ' Error details can be found in the logs when running with the env'
        ' variable AUTOGRAPH_VERBOSITY=5. Please report this to the AutoGraph'
        ' team. Cause: %s', target_entity, e)

    return f(*args, **kwargs)

  result = converted_f(*effective_args, **kwargs)

  # The converted function's closure is simply inserted into the function's
  # module __dict__. Since modules are permanently cached, that results in
  # leaking the entire closure.
  # Normally, it's not safe to delete the module because that may release said
  # closure as well. However, in the case of converted_call we are certain the
  # function will not be executed again, so the closure should no longer be
  # needed so long as the function doesn't return any executable code.
  # TODO(mdan): Attach the closure properly, using cells.
  if all(map(_is_not_callable, nest.flatten(result))):
    del sys.modules[converted_f.__module__]

  return result
Exemplo n.º 8
0
    def transform_function(self, fn, user_context):
        """Transforms a function, caching the result and loading it as Python.

        Wraps the parent's `transform_function` (see
        GenericTranspiler.transform_function). Transformed code is cached per
        `(fn, self.get_caching_key(user_context))` and loaded through a
        `_PythonFnFactory`. The symbols returned by `get_extra_locals` are made
        visible to the generated code as local variables.

        Args:
          fn: A function or lambda.
          user_context: An opaque object (may be None) that is forwarded to
            transform_ast, through the ctx.user_context argument.

        Returns:
          A tuple:
            * A function or lambda with the same signature and closure as `fn`
            * The temporary module into which the transformed function was
              loaded
            * The source map as a
                Dict[origin_info.LineLocation, origin_info.OriginInfo]
        """
        subkey = self.get_caching_key(user_context)

        if self._cache.has(fn, subkey):
            # Already cached: served without taking the lock.
            factory = self._cached_factory(fn, subkey)
        else:
            with self._cache_lock:
                # Another thread may have filled the entry while this one
                # waited for the lock, so look again before transforming.
                if self._cache.has(fn, subkey):
                    factory = self._cached_factory(fn, subkey)
                else:
                    logging.log(1, '%s is not cached for subkey %s', fn,
                                subkey)
                    # TODO(mdan): Confusing overloading pattern. Fix.
                    nodes, ctx = super(PyToPy, self).transform_function(
                        fn, user_context)
                    entity_name = ctx.info.name

                    if not isinstance(nodes, gast.Lambda):
                        nodes.name = entity_name
                    else:
                        # A lambda has no name of its own; bind it to
                        # `entity_name` with an assignment statement.
                        store_target = gast.Name(entity_name,
                                                 ctx=gast.Store(),
                                                 annotation=None,
                                                 type_comment=None)
                        nodes = gast.Assign(targets=[store_target],
                                            value=nodes)

                    if logging.has_verbosity(2):
                        logging.log(2, 'Transformed %s:\n\n%s\n', fn,
                                    parser.unparse(nodes))

                    factory = _PythonFnFactory(entity_name,
                                               fn.__code__.co_freevars,
                                               self.get_extra_locals())
                    factory.create(nodes,
                                   ctx.namer,
                                   future_features=ctx.info.future_features)
                    self._cache[fn][subkey] = factory

        # Rebuild a live function bound to the caller's globals, closure and
        # defaults.
        kw_defaults = getattr(fn, '__kwdefaults__', None)
        instance = factory.instantiate(globals_=fn.__globals__,
                                       closure=fn.__closure__ or (),
                                       defaults=fn.__defaults__,
                                       kwdefaults=kw_defaults)
        return instance, factory.module, factory.source_map
Exemplo n.º 9
0
def create_source_map(nodes, code, filename):
  """Maps lines of compiled code back to the origins annotated on an AST.

  `code` is reparsed, and the result is walked in lockstep with `nodes`. For
  each pair of nodes where both sides carry an `anno.Basic.ORIGIN`
  annotation, the line of the reparsed node (qualified by `filename`) is
  mapped to the origin recorded on the original node.

  Args:
    nodes: Iterable[ast.AST, ...], the annotated nodes that `code` was
      generated from.
    code: Text, the generated source code.
    filename: Optional[Text], used in the LineLocation keys of the result.

  Returns:
    Dict[LineLocation, OriginInfo], mapping locations in code to locations
    indicated by origin annotations in node.

  Raises:
    ValueError: re-raised from `ast_util.parallel_walk`, presumably when the
      two trees do not line up. At verbosity >= 3 a diff of the two trees is
      logged before re-raising.
  """
  reparsed = parser.parse_str(code, preamble_len=0, single_node=False)
  for reparsed_node in reparsed:
    resolve(reparsed_node, code)

  def _keep_existing(existing, candidate):
    """Returns True when `existing` should win over `candidate`."""
    # Overlaps come mostly from child nodes and rarely span different line
    # locations; decorated functions are an exception, since both lines map to
    # the same line in the AST.
    # NOTE(review): if `loc.line_loc` embeds the line number (as the
    # LineLocation keys built below do), equal line_locs imply equal linenos
    # and the `>=` test always holds when reached - confirm intent.
    same_line = existing.loc.line_loc == candidate.loc.line_loc
    if same_line and existing.loc.lineno >= candidate.loc.lineno:
      return True
    # Otherwise keep the leftmost node.
    return existing.loc.col_offset <= candidate.loc.col_offset

  source_map = {}

  try:
    for original, regenerated in ast_util.parallel_walk(nodes, reparsed):
      # Generated code may lack an origin; such pairs are simply skipped.
      # TODO(mdan): Generated code should always be mapped to something.
      origin = anno.getanno(original, anno.Basic.ORIGIN, default=None)
      final = anno.getanno(regenerated, anno.Basic.ORIGIN, default=None)
      if origin is None or final is None:
        continue

      key = LineLocation(filename, final.loc.lineno)
      previous = source_map.get(key)
      if previous is not None and _keep_existing(previous, origin):
        continue
      source_map[key] = origin
  except ValueError:
    if logging.has_verbosity(3):
      for original_node, reparsed_node in zip(nodes, reparsed):
        before_text = pretty_printer.fmt(original_node, color=False,
                                         noanno=True)
        after_text = pretty_printer.fmt(reparsed_node, color=False,
                                        noanno=True)
        diff_lines = difflib.context_diff(
            before_text.split('\n'),
            after_text.split('\n'),
            fromfile='Original nodes',
            tofile='Reparsed nodes',
            n=7)
        logging.log(3, 'AST seems to lack integrity. Diff:\n%s',
                    '\n'.join(diff_lines))
    raise

  return source_map
Exemplo n.º 10
0
def converted_call(f, options, args, kwargs, caller_fn_scope=None):
    """Compiles a function call inline.

    For internal use only.

    Several kinds of callable are executed through `_call_unconverted`:
      * builtins (dispatched to their overloads),
      * already-converted entities,
      * wrapt and lru_cache wrappers,
      * constructors,
      * members of a few stdlib modules,
      * TensorFlow plugin functions,
      * whitelisted entities,
      * native bindings and dynamically exec'd code.
    Everything else is converted with `conversion.convert`, and the converted
    function is called. If conversion fails, `f` runs unconverted, unless
    strict conversion mode is enabled, in which case the error is re-raised.

    Args:
      f: The function to convert.
      options: converter.ConversionOptions
      args: Tuple, the original positional arguments of f
      kwargs: Dict (may be None), the original keyword arguments of f
      caller_fn_scope: Optional[function_wrappers.FunctionScope], the function
        scope of the converted function in which this call was originally made.

    Returns:
      Any, the result of executing a possibly-converted `f` with the given
        arguments.
    """
    logging.log(1, 'Converted call: %s\n    args: %s\n    kwargs: %s\n', f,
                args, kwargs)

    if conversion.check_cached_unconverted(f, options):
        return _call_unconverted(f, args, kwargs, options, False)

    if inspect_utils.isbuiltin(f):
        if f is eval:
            return py_builtins.eval_in_original_context(
                f, args, caller_fn_scope)
        if f is super:
            return py_builtins.super_in_original_context(
                f, args, caller_fn_scope)
        if kwargs:
            return py_builtins.overload_of(f)(*args, **kwargs)
        else:
            return py_builtins.overload_of(f)(*args)

    # TODO(mdan): Clean up the naming inconsistency.
    if hasattr(f, 'autograph_info__') or hasattr(f, '__ag_compiled'):
        logging.log(2, 'Permanently whitelisted: %s: already converted', f)
        return _call_unconverted(f, args, kwargs, options)

    # TODO(b/122265385): Remove this bypass.
    if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper')
            or _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
        logging.warn(
            '{} appears to be decorated by wrapt, which is not yet supported'
            ' by AutoGraph. The function will run as-is.'
            ' You may still apply AutoGraph before the wrapt decorator.'.
            format(f))
        logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
        return _call_unconverted(f, args, kwargs, options)

    if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
        logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
        return _call_unconverted(f, args, kwargs, options)

    # Constructors are permanently whitelisted.
    # TODO(mdan): Toggle as experimental feature instead.
    # TODO(b/124016764): Remove this limitation.
    if tf_inspect.isclass(f):
        logging.log(2, 'Permanently whitelisted: %s: constructor', f)
        return _call_unconverted(f, args, kwargs, options)

    # Other built-in modules are permanently whitelisted.
    # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
    if any(f in m.__dict__.values()
           for m in (collections, pdb, copy, inspect, re)):
        logging.log(2, 'Permanently whitelisted: %s: part of builtin module',
                    f)
        return _call_unconverted(f, args, kwargs, options)

    # Custom ops and kernels are also permanently whitelisted.
    # See tensorflow.framework.load_library.
    # `f.__module__` holds the module *name* (a str), so testing the marker
    # attribute on it could never succeed. The marker is looked up on the
    # module object that name refers to (presumably where load_library sets it).
    if hasattr(f, '__module__'):
        f_module = tf_inspect.getmodule(f)
        if (f_module is not None
                and hasattr(f_module, '_IS_TENSORFLOW_PLUGIN')):
            logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', f)
            return _call_unconverted(f, args, kwargs, options)

    if not options.user_requested and conversion.is_whitelisted_for_graph(f):
        return _call_unconverted(f, args, kwargs, options)

    # internal_convert_user_code is for example turned off when issuing a dynamic
    # call conversion from generated code while in nonrecursive mode. In that
    # case we evidently don't want to recurse, but we still have to convert
    # things like builtins.
    if not options.internal_convert_user_code:
        return _call_unconverted(f, args, kwargs, options)

    # Bound before the guarded block so the error handler below can always
    # name the entity, even if something fails before it is reassigned.
    target_entity = f

    # TODO(mdan): Move this entire block inside to_graph.
    try:  # Begin of transformation error guards

        # Unwrap functools.partial objects
        # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
        # TODO(b/120224672): This unwrapping should be done before the checks above.
        while isinstance(f, functools.partial):
            args = f.args + args
            new_kwargs = {}
            if f.keywords is not None:
                new_kwargs.update(f.keywords)
            if kwargs is not None:
                new_kwargs.update(kwargs)
            kwargs = new_kwargs
            f = f.func

        if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
            # Regular functions
            target_entity = f
            f_self = inspect_utils.getmethodself(f)

            # TODO(b/119246461): This may be more elegantly handled using __get__?
            if f_self is not None:
                effective_args = (f_self, ) + args
            else:
                effective_args = args

        elif hasattr(f, '__call__') and hasattr(f, '__class__'):
            # Callable objects
            target_entity = f.__call__
            effective_args = (f, ) + args

        elif tf_inspect.isclass(f):
            # Constructors
            # Note: Until we support class constructurs, and enable whole-class
            # conversion with an experimental flag, this branch is dead code.
            # TODO(mdan): Consider removing unless there is a compelling use case.
            target_entity = f
            effective_args = args

        else:
            target_entity = f
            raise NotImplementedError('unknown callable type "%s"' % type(f))

        if not tf_inspect.isclass(target_entity):
            if not hasattr(target_entity, '__code__'):
                logging.log(2, 'Permanently whitelisted: %s: native binding',
                            target_entity)
                return _call_unconverted(f, args, kwargs, options)
            elif (hasattr(target_entity.__code__, 'co_filename')
                  and target_entity.__code__.co_filename == '<string>'):
                # TODO(mdan): __globals__['txt'] might work in Py3.
                logging.log(
                    2, 'Permanently whitelisted: %s: dynamic code (exec?)',
                    target_entity)
                return _call_unconverted(f, args, kwargs, options)

        program_ctx = converter.ProgramContext(
            options=options,
            autograph_module=tf_inspect.getmodule(converted_call))
        converted_f = conversion.convert(target_entity, program_ctx)

        if logging.has_verbosity(2):
            logging.log(2, 'Defaults of %s : %s', converted_f,
                        converted_f.__defaults__)
            if six.PY3:
                logging.log(2, 'KW defaults of %s : %s', converted_f,
                            converted_f.__kwdefaults__)

            if kwargs is not None:
                callargs = tf_inspect.getcallargs(converted_f, *effective_args,
                                                  **kwargs)
            else:
                callargs = tf_inspect.getcallargs(converted_f, *effective_args)

            formatted_callargs = '\n'.join('    {}: {}'.format(k, v)
                                           for k, v in callargs.items())
            logging.log(2, 'Calling %s with\n%s\n', converted_f,
                        formatted_callargs)

    except Exception as e:  # pylint:disable=broad-except
        logging.log(1,
                    'Error transforming entity %s',
                    target_entity,
                    exc_info=True)
        if is_autograph_strict_conversion_mode():
            raise
        if _errors_are_normally_possible(target_entity, e):
            logging.warn(
                'AutoGraph could not transform %s and will run it as-is.\n'
                'Cause: %s', target_entity, e)
        else:
            logging.warn(
                'AutoGraph could not transform %s and will run it as-is.\n'
                'Please report this to the TensorFlow team. When filing the bug, set'
                ' the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and'
                ' attach the full output.\n'
                'Cause: %s', target_entity, e)
        return _call_unconverted(f, args, kwargs, options)

    with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
        try:
            if kwargs is not None:
                result = converted_f(*effective_args, **kwargs)
            else:
                result = converted_f(*effective_args)
        except Exception as e:
            _attach_metadata(e, converted_f, True)
            raise

    return result
Exemplo n.º 11
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
  """Compile a Python entity into equivalent TensorFlow.

  The function will also recursively compile all the entities that `o`
  references, updating `dependency_cache`.

  This function is reentrant, and relies on dependency_cache to avoid
  generating duplicate code.

  Args:
    o: A Python entity.
    program_ctx: A ProgramContext object.
    arg_values: A dict containing value hints for symbols like function
        parameters.
    arg_types: A dict containing type hints for symbols like function
        parameters.

  Returns:
    A tuple (ast, new_name, namespace):
        * ast: An AST representing an entity with interface equivalent to `o`,
            but which when executed it creates TF a graph.
        * new_name: The symbol name under which the new entity can be found.
        * namespace: A dict mapping all symbols visible to the converted entity,
            keyed by their symbol name.

  Raises:
    ValueError: if the entity type is not supported.
    NotImplementedError: if `o` is an object other than a class, function or
        method.
  """
  logging.log(1, 'Converting %s', o)

  if tf_inspect.isclass(o):
    node, name, ns = class_to_graph(o, program_ctx)
  elif tf_inspect.isfunction(o):
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  elif tf_inspect.ismethod(o):
    node, name, ns = function_to_graph(o, program_ctx, arg_values, arg_types)
  # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
  elif hasattr(o, '__class__'):
    raise NotImplementedError(
        'Object conversion is not yet supported. If you are '
        'trying to convert code that uses an existing object, '
        'try including the creation of that object in the '
        'conversion. For example, instead of converting the method '
        'of a class, try converting the entire class instead. '
        'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
        'python/autograph/README.md#using-the-functional-api '
        'for more information.')
  else:
    raise ValueError(
        'Entity "%s" has unsupported type "%s". Only functions and classes are '
        'supported for now.' % (o, type(o)))

  # TODO(mdan): This is temporary. it should be created using a converter.
  # TODO(mdan): The attribute should be added with a helper, not directly.
  # The helper can ensure there are no collisions.
  template = '''
      entity.autograph_info__ = {}
  '''
  node.extend(templates.replace(template, entity=name))

  program_ctx.add_to_cache(o, node)

  if logging.has_verbosity(2):
    logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                compiler.ast_to_source(node))
  if logging.has_verbosity(4):
    for n in node:
      logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                  pretty_printer.fmt(n, color=False))

  if program_ctx.options.recursive:
    while True:
      # Pick the first referenced entity that is neither compiled yet nor
      # excluded by the class-member rule. Excluded members are skipped inside
      # the scan: they never enter dependency_cache, so re-selecting them on
      # every pass would loop forever.
      candidate = None
      for obj in program_ctx.name_map.keys():
        if obj in program_ctx.dependency_cache:
          continue
        if (hasattr(obj, 'im_class') and
            getattr(obj, 'im_class') not in program_ctx.partial_types):
          # Class members are converted with their objects, unless they're
          # only converted partially.
          continue
        candidate = obj
        break
      if candidate is None:
        break
      entity_to_graph(candidate, program_ctx, {}, {})

  return node, name, ns
Exemplo n.º 12
0
def entity_to_graph(o, program_ctx, arg_values, arg_types):
    """Compile a Python entity into equivalent TensorFlow.

    The function will also recursively compile all the entities that `o`
    references, updating `dependency_cache`.

    This function is reentrant, and relies on dependency_cache to avoid
    generating duplicate code.

    Args:
      o: A Python entity.
      program_ctx: A ProgramContext object.
      arg_values: A dict containing value hints for symbols like function
          parameters.
      arg_types: A dict containing type hints for symbols like function
          parameters.

    Returns:
      A tuple (ast, new_name, namespace):
          * ast: An AST representing an entity with interface equivalent to `o`,
              but which when executed it creates TF a graph.
          * new_name: The symbol name under which the new entity can be found.
          * namespace: A dict mapping all symbols visible to the converted entity,
              keyed by their symbol name.

    Raises:
      ValueError: if the entity type is not supported.
      NotImplementedError: if `o` is an object other than a class, function or
          method.
    """
    logging.log(1, 'Converting %s', o)

    if tf_inspect.isclass(o):
        node, name, ns = class_to_graph(o, program_ctx)
    elif tf_inspect.isfunction(o):
        node, name, ns = function_to_graph(o, program_ctx, arg_values,
                                           arg_types)
    elif tf_inspect.ismethod(o):
        node, name, ns = function_to_graph(o, program_ctx, arg_values,
                                           arg_types)
    # TODO(mdan,yashkatariya): Remove when object conversion is implemented.
    elif hasattr(o, '__class__'):
        raise NotImplementedError(
            'Object conversion is not yet supported. If you are '
            'trying to convert code that uses an existing object, '
            'try including the creation of that object in the '
            'conversion. For example, instead of converting the method '
            'of a class, try converting the entire class instead. '
            'See https://github.com/tensorflow/tensorflow/blob/master/tensorflow/'
            'contrib/autograph/README.md#using-the-functional-api '
            'for more information.')
    else:
        raise ValueError(
            'Entity "%s" has unsupported type "%s". Only functions and classes are '
            'supported for now.' % (o, type(o)))

    # TODO(mdan): This is temporary. it should be created using a converter.
    # TODO(mdan): The attribute should be added with a helper, not directly.
    # The helper can ensure there are no collisions.
    template = '''
      entity.autograph_info__ = {}
  '''
    node.extend(templates.replace(template, entity=name))

    program_ctx.add_to_cache(o, node)

    if logging.has_verbosity(2):
        logging.log(2, 'Compiled output of %s:\n\n%s\n', o,
                    compiler.ast_to_source(node))
    if logging.has_verbosity(4):
        for n in node:
            logging.log(4, 'Compiled AST of %s:\n\n%s\n\n', o,
                        pretty_printer.fmt(n, color=False))

    if program_ctx.options.recursive:
        while True:
            # Pick the first referenced entity that is neither compiled yet
            # nor excluded by the class-member rule. Excluded members are
            # skipped inside the scan: they never enter dependency_cache, so
            # re-selecting them on every pass would loop forever.
            candidate = None
            for obj in program_ctx.name_map.keys():
                if obj in program_ctx.dependency_cache:
                    continue
                if (hasattr(obj, 'im_class') and getattr(
                        obj, 'im_class') not in program_ctx.partial_types):
                    # Class members are converted with their objects, unless
                    # they're only converted partially.
                    continue
                candidate = obj
                break
            if candidate is None:
                break
            entity_to_graph(candidate, program_ctx, {}, {})

    return node, name, ns
Exemplo n.º 13
0
def converted_call(f, owner, options, args, kwargs):
    """Compiles a function call inline. For internal use only.

    Args:
      f: The callable to convert or, when `owner` is given, the name (a
        string) of the attribute on `owner` that holds the callable.
      owner: Optional object on which `f` is looked up by name.
      options: Conversion options. This function reads `force_conversion`,
        `internal_convert_user_code`, `recursive` and `optional_features`.
      args: Tuple, the positional arguments of the call.
      kwargs: Optional dict, the keyword arguments of the call. None is
        treated the same as an empty dict.

    Returns:
      The result of calling `f` with `args` and `kwargs`, converted with
      AutoGraph when possible and called unconverted otherwise.

    Raises:
      ValueError: if `owner` is given but `f` is not a string.
    """
    logging.log(
        1, 'Converted call: %s; owner: %s\n    args: %s\n    kwargs: %s\n', f,
        owner, args, kwargs)

    # Every call path below unpacks `**kwargs`; a None value would otherwise
    # raise TypeError (e.g. in the builtin and functools.partial paths).
    if kwargs is None:
        kwargs = {}

    if owner is not None:
        if not isinstance(f, str):
            raise ValueError(
                'When owner is specified, the function name must be specified as'
                ' a string: {}'.format(f))

        # Special case when the owner is a 'super' object. In that case lookups of
        # dynamic attributes won't work. See
        # inspect_utils.SuperWrapperForDynamicAttrs.
        if isinstance(owner, super):
            owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

        f = getattr(owner, f)

    if inspect_utils.isbuiltin(f):
        return py_builtins.overload_of(f)(*args, **kwargs)

    if _is_known_loaded_type(f, 'weakref', 'ref'):
        logging.log(2, 'Permanently whitelisted: %s: weakref', f)
        return _call_unconverted(f, args, kwargs)

    # TODO(b/122265385): Remove this bypass.
    if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper')
            or _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
        logging.warn(
            'Entity {} appears to be decorated by wrapt, which is not yet supported'
            ' by AutoGraph. The function will be called without transformation.'
            ' You may however apply AutoGraph before the decorator.'.format(f))
        logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
        return _call_unconverted(f, args, kwargs)

    # Constructors are permanently whitelisted.
    # TODO(mdan): Toggle as experimental feature instead.
    # TODO(b/124016764): Remove this limitation.
    if tf_inspect.isclass(f):
        logging.log(2, 'Permanently whitelisted: %s: constructor', f)
        return _call_unconverted(f, args, kwargs)

    # Other built-in modules are permanently whitelisted.
    # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
    # Note: TF linter disallows importing inspect.
    if any(f in m.__dict__.values()
           for m in (collections, pdb, copy, tf_inspect._inspect)):  # pylint:disable=protected-access
        logging.log(2, 'Permanently whitelisted: %s: part of builtin module',
                    f)
        return _call_unconverted(f, args, kwargs)

    if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
        return _call_unconverted(f, args, kwargs)

    # internal_convert_user_code is for example turned off when issuing a dynamic
    # call conversion from generated code while in nonrecursive mode. In that
    # case we evidently don't want to recurse, but we still have to convert
    # things like builtins.
    if not options.internal_convert_user_code:
        return _call_unconverted(f, args, kwargs)

    # Bound before the try so the error handler below can always log it;
    # otherwise a failure while unwrapping a functools.partial would leave it
    # unbound and the handler itself would raise UnboundLocalError.
    target_entity = f

    # TODO(mdan): Move this entire block inside to_graph.
    try:  # Begin of transformation error guards

        # Unwrap functools.partial objects
        # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
        while isinstance(f, functools.partial):
            new_args = f.args + args
            new_kwargs = {}
            if f.keywords is not None:
                new_kwargs.update(f.keywords)
            new_kwargs.update(kwargs)
            # Commit all three at once so f, args and kwargs never disagree
            # if one of the merges above raises.
            f, args, kwargs = f.func, new_args, new_kwargs

        if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
            # Regular functions
            target_entity = f
            arg_map_target = f
            f_self = inspect_utils.getmethodself(f)

            # TODO(b/119246461): This may be more elegantly handled using __get__?
            if f_self is not None:
                effective_args = (f_self, ) + args
            else:
                effective_args = args

        elif tf_inspect.isclass(f):
            # Constructors
            # Classes passed directly are returned unconverted above, so this
            # branch is only reached when a functools.partial wraps a class.
            # TODO(mdan): Consider removing unless there is a compelling use case.
            target_entity = f
            arg_map_target = f.__init__
            effective_args = args

        elif hasattr(f, '__call__') and hasattr(f, '__class__'):
            # Callable objects
            target_entity = f.__call__
            arg_map_target = f.__call__
            effective_args = (f, ) + args

        else:
            target_entity = f
            raise NotImplementedError('unknown callable type "%s"' % type(f))

        arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
        arg_types = {}
        for name, arg in arg_values.items():
            arg_class = arg.__class__
            arg_types[name] = (arg_class.__name__, arg_class)

        converted_f = to_graph(
            target_entity,
            recursive=options.recursive,
            arg_values=arg_values,
            arg_types=arg_types,
            experimental_optional_features=options.optional_features)

        if logging.has_verbosity(2):
            logging.log(2, 'Defaults of %s : %s', converted_f,
                        converted_f.__defaults__)
            callargs = tf_inspect.getcallargs(converted_f, *effective_args,
                                              **kwargs)
            formatted_callargs = '\n'.join('    {}: {}'.format(k, v)
                                           for k, v in callargs.items())
            logging.log(2, 'Calling %s with\n%s\n', converted_f,
                        formatted_callargs)

    # TODO(mdan): Reduce this list.
    except (errors.AutoGraphError, AssertionError, AttributeError, IndexError,
            KeyError, NameError, NotImplementedError, SyntaxError, TypeError,
            ValueError, IOError) as e:

        logging.log(1,
                    'Error transforming entity %s',
                    target_entity,
                    exc_info=True)

        if is_autograph_strict_conversion_mode():
            raise

        logging.warn(
            'Entity %s could not be transformed and will be staged without change.'
            ' Error details can be found in the logs when running with the env'
            ' variable AUTOGRAPH_VERBOSITY >= 1. Please report this to the'
            ' AutoGraph team. Cause: %s', target_entity, e)

        return _call_unconverted(f, args, kwargs)

    result = converted_f(*effective_args, **kwargs)

    # The converted function's closure is simply inserted into the function's
    # module __dict__. Since modules are permanently cached, that results in
    # leaking the entire closure.
    # Normally, it's not safe to delete the module because that may release said
    # closure as well. However, in the case of converted_call we are certain the
    # function will not be executed again, so the closure should no longer be
    # needed so long as the function doesn't return any executable code.
    # TODO(mdan): Attach the closure properly, using cells.
    if all(map(_is_not_callable, nest.flatten(result))):
        # pop() rather than del: a module entry that is already gone must not
        # discard a result that has already been computed.
        sys.modules.pop(converted_f.__module__, None)

    return result
Exemplo n.º 14
0
def converted_call(f, args, kwargs, caller_fn_scope=None, options=None):
    """Converts a function call inline.

    For internal use only.

    The parameter order favors readable generated code, for example:

      ag__.converted_call(f, (arg1, arg2), None, fscope)
      ag__.converted_call(f, (), dict(arg1=val1, **kwargs), fscope)
      ag__.converted_call(f, (arg1, arg2) + varargs, dict(**kwargs), lscope)

    Args:
      f: The callable to (possibly) convert and then invoke.
      args: Tuple, the positional arguments originally passed to f.
      kwargs: Optional[Dict], the keyword arguments originally passed to f.
      caller_fn_scope: Optional[function_wrappers.FunctionScope], the scope of
        the converted function from which this call is made.
      options: Optional[converter.ConversionOptions]. When omitted,
        caller_fn_scope.callopts is used; at least one of options and
        caller_fn_scope must be provided.

    Returns:
      Any, whatever f (converted or not) returns for the given arguments.

    Raises:
      ValueError: if both options and caller_fn_scope are None.
    """
    logging.log(1, 'Converted call: %s\n    args: %s\n    kwargs: %s\n', f,
                args, kwargs)

    if options is None:
        if caller_fn_scope is None:
            raise ValueError(
                'either caller_fn_scope or options must have a value')
        options = caller_fn_scope.callopts

    # Cheap early exits: the call runs exactly as written.
    if conversion.is_in_allowlist_cache(f, options):
        logging.log(2, 'Allowlisted %s: from cache', f)
        return _call_unconverted(f, args, kwargs, options, False)

    if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED:
        logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f)
        return _call_unconverted(f, args, kwargs, options, False)

    if is_autograph_artifact(f):
        logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f)
        return _call_unconverted(f, args, kwargs, options)

    # functools.partial: fold the bound arguments in and re-dispatch, so every
    # check in this function is applied to the wrapped callable.
    if isinstance(f, functools.partial):
        # Copy the partial's keywords instead of updating them in place.
        merged_kwargs = {} if f.keywords is None else f.keywords.copy()
        if kwargs is not None:
            merged_kwargs.update(kwargs)
        merged_args = f.args + args
        logging.log(3, 'Forwarding call of partial %s with\n%s\n%s\n', f,
                    merged_args, merged_kwargs)
        return converted_call(f.func,
                              merged_args,
                              merged_kwargs,
                              caller_fn_scope=caller_fn_scope,
                              options=options)

    # Builtins: a few need the caller's scope, the rest go through their
    # AutoGraph overloads.
    if inspect_utils.isbuiltin(f):
        if f is eval:
            return py_builtins.eval_in_original_context(
                f, args, caller_fn_scope)
        if f is super:
            return py_builtins.super_in_original_context(
                f, args, caller_fn_scope)
        if f is globals:
            return py_builtins.globals_in_original_context(caller_fn_scope)
        if f is locals:
            return py_builtins.locals_in_original_context(caller_fn_scope)
        overload = py_builtins.overload_of(f)
        return overload(*args, **kwargs) if kwargs else overload(*args)

    # Unsupported or allowlisted entities are never converted. The same holds
    # when internal_convert_user_code is off, e.g. for a dynamic call issued
    # from generated code in nonrecursive mode: builtins were still handled
    # above, but user code is not recursed into.
    if (conversion.is_unsupported(f)
            or (not options.user_requested and conversion.is_allowlisted(f))
            or not options.internal_convert_user_code):
        return _call_unconverted(f, args, kwargs, options)

    # Determine the function object to convert and the positional arguments
    # it has to be called with.
    try:
        if inspect.ismethod(f) or inspect.isfunction(f):
            target_entity = f
            bound_self = getattr(f, '__self__', None)
            if bound_self is None:
                effective_args = args
            else:
                # Calls routed through a TfMethodTarget use its wrapped target.
                if isinstance(bound_self, function.TfMethodTarget):
                    bound_self = bound_self.target
                effective_args = (bound_self,) + args
        elif hasattr(f, '__class__') and hasattr(f.__class__, '__call__'):
            # Callable instances: __call__ is looked up on the type, matching
            # Python's special-method lookup rules, see:
            # https://docs.python.org/3/reference/datamodel.html#specialnames
            # TODO(mdan): Recurse into converted_call to simplify other verifications.
            # This should be handled in the same way as partials.
            target_entity = f.__class__.__call__
            effective_args = (f,) + args
        else:
            target_entity = f
            raise NotImplementedError('unknown callable type "%s"' % type(f))
    except Exception as e:  # pylint:disable=broad-except
        logging.log(1,
                    'Error transforming entity %s',
                    target_entity,
                    exc_info=True)
        if is_autograph_strict_conversion_mode():
            raise
        return _fall_back_unconverted(f, args, kwargs, options, e)

    # Only functions with Python bytecode that did not come from dynamically
    # executed source can be converted.
    if not hasattr(target_entity, '__code__'):
        logging.log(2, 'Permanently allowed: %s: native binding',
                    target_entity)
        return _call_unconverted(f, args, kwargs, options)

    # TODO(mdan): __globals__['txt'] might work in Py3.
    if getattr(target_entity.__code__, 'co_filename', None) == '<string>':
        logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)',
                    target_entity)
        return _call_unconverted(f, args, kwargs, options)

    try:
        program_ctx = converter.ProgramContext(options=options)
        converted_f = _convert_actual(target_entity, program_ctx)
        if logging.has_verbosity(2):
            _log_callargs(converted_f, effective_args, kwargs)
    except Exception as e:  # pylint:disable=broad-except
        logging.log(1,
                    'Error transforming entity %s',
                    target_entity,
                    exc_info=True)
        if is_autograph_strict_conversion_mode():
            raise
        return _fall_back_unconverted(f, args, kwargs, options, e)

    call_kwargs = {} if kwargs is None else kwargs
    with StackTraceMapper(converted_f), tf_stack.CurrentModuleFilter():
        try:
            return converted_f(*effective_args, **call_kwargs)
        except Exception as e:
            _attach_error_metadata(e, converted_f)
            raise
Exemplo n.º 15
0
def converted_call(f, owner, options, args, kwargs):
  """Compiles a function call inline. For internal use only.

  Args:
    f: The callable to convert or, when `owner` is given, the name (a string)
      of the attribute on `owner` that holds the callable.
    owner: Optional object on which `f` is looked up by name.
    options: Conversion options. This function reads `force_conversion`,
      `internal_convert_user_code`, `recursive` and `optional_features`.
    args: Tuple, the positional arguments of the call.
    kwargs: Optional dict, the keyword arguments of the call; may be None.

  Returns:
    The result of calling `f` with `args` and `kwargs`, converted with
    AutoGraph when possible and called unconverted otherwise.

  Raises:
    ValueError: if `owner` is given but `f` is not a string.
  """
  if owner is not None:
    if not isinstance(f, str):
      raise ValueError(
          'When owner is specified, the function name must be specified as'
          ' a string: {}'.format(f))
    owner_attr = f

    # Special case when the owner is a 'super' object. In that case lookups of
    # dynamic attributes won't work. See
    # inspect_utils.SuperWrapperForDynamicAttrs.
    if isinstance(owner, super):
      owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

    f = getattr(owner, f)

  if logging.has_verbosity(1):
    if owner is not None:
      composite_desc = '("{}" attr of {})'.format(owner_attr, owner)
    else:
      composite_desc = ''

    logging.log(1,
                'Converted call: %s %s\n    args: %s\n    kwargs: %s\n',
                f, composite_desc, args, kwargs)

  if inspect_utils.isbuiltin(f):
    if kwargs:
      return py_builtins.overload_of(f)(*args, **kwargs)
    else:
      return py_builtins.overload_of(f)(*args)

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
    logging.warn(
        'Entity {} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will be called without transformation.'
        ' You may however apply AutoGraph before the decorator.'.format(f))
    logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
    return _call_unconverted(f, args, kwargs)

  if _is_known_loaded_type(f, 'functools', '_lru_cache_wrapper'):
    logging.log(2, 'Permanently whitelisted: %s: lru_cache', f)
    return _call_unconverted(f, args, kwargs)

  # Constructors are permanently whitelisted.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if tf_inspect.isclass(f):
    logging.log(2, 'Permanently whitelisted: %s: constructor', f)
    return _call_unconverted(f, args, kwargs)

  # Other built-in modules are permanently whitelisted.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  if any(f in m.__dict__.values() for m in (collections, pdb, copy, inspect)):
    logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
    return _call_unconverted(f, args, kwargs)

  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
    return _call_unconverted(f, args, kwargs)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs)

  # Bound before the try so the error handler below can always log it;
  # otherwise a failure while unwrapping a functools.partial would leave it
  # unbound and the handler itself would raise UnboundLocalError.
  target_entity = f

  # TODO(mdan): Move this entire block inside to_graph.
  try:  # Begin of transformation error guards

    # Unwrap functools.partial objects
    # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
    while isinstance(f, functools.partial):
      new_args = f.args + args
      new_kwargs = {}
      if f.keywords is not None:
        new_kwargs.update(f.keywords)
      if kwargs is not None:
        new_kwargs.update(kwargs)
      # Commit all three at once so f, args and kwargs never disagree if one
      # of the merges above raises.
      f, args, kwargs = f.func, new_args, new_kwargs

    if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
      # Regular functions
      target_entity = f
      f_self = inspect_utils.getmethodself(f)

      # TODO(b/119246461): This may be more elegantly handled using __get__?
      if f_self is not None:
        effective_args = (f_self,) + args
      else:
        effective_args = args

    elif tf_inspect.isclass(f):
      # Constructors
      # Classes passed directly are returned unconverted above, so this branch
      # is only reached when a functools.partial wraps a class.
      # TODO(mdan): Consider removing unless there is a compelling use case.
      target_entity = f
      effective_args = args

    elif hasattr(f, '__call__') and hasattr(f, '__class__'):
      # Callable objects
      target_entity = f.__call__
      effective_args = (f,) + args

    else:
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

    # Entities without Python bytecode (e.g. C extensions) cannot be parsed.
    if (not tf_inspect.isclass(target_entity) and
        not hasattr(target_entity, '__code__')):
      logging.log(
          2, 'Permanently whitelisted: %s: native binding', target_entity)
      return _call_unconverted(f, args, kwargs)

    converted_f = to_graph(
        target_entity,
        recursive=options.recursive,
        arg_values=None,
        arg_types=None,
        experimental_optional_features=options.optional_features)

    if logging.has_verbosity(2):
      logging.log(2, 'Defaults of %s : %s', converted_f,
                  converted_f.__defaults__)
      if kwargs is not None:
        callargs = tf_inspect.getcallargs(
            converted_f, *effective_args, **kwargs)
      else:
        callargs = tf_inspect.getcallargs(converted_f, *effective_args)
      formatted_callargs = '\n'.join(
          '    {}: {}'.format(k, v) for k, v in callargs.items())
      logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)

  # TODO(mdan): Reduce this list.
  except (errors.AutoGraphError, AssertionError, AttributeError, IndexError,
          KeyError, NameError, NotImplementedError, SyntaxError, TypeError,
          ValueError, IOError) as e:

    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)

    if is_autograph_strict_conversion_mode():
      raise

    logging.warn(
        'Entity %s could not be transformed and will be executed as-is.'
        ' Some features (e.g. tensor-dependent conditionals and loops) may not'
        ' work as expected.'
        ' Error details can be found in the logs when running with the env'
        ' variable AUTOGRAPH_VERBOSITY >= 1. Please report this to the'
        ' AutoGraph team. Cause: %s', target_entity, e)

    return _call_unconverted(f, args, kwargs)

  if kwargs is not None:
    result = converted_f(*effective_args, **kwargs)
  else:
    result = converted_f(*effective_args)

  return result
Exemplo n.º 16
0
def converted_call(f, owner, options, args, kwargs):
  """Compiles a function call inline. For internal use only.

  Args:
    f: The callable to convert or, when `owner` is given, the name (a string)
      of the attribute on `owner` that holds the callable.
    owner: Optional object on which `f` is looked up by name. When given, it
      is also used as the implicit first argument of bound-method calls.
    options: Conversion options. This function reads `force_conversion`,
      `internal_convert_user_code`, `recursive`, `optional_features`,
      `strip_decorators` and `verbose`.
    args: Tuple, the positional arguments of the call.
    kwargs: Optional dict, the keyword arguments of the call. None is treated
      the same as an empty dict.

  Returns:
    The result of calling `f` with `args` and `kwargs`, converted with
    AutoGraph when possible and called unconverted otherwise.

  Raises:
    ValueError: if `owner` is given but `f` is not a string.
  """
  logging.log(1,
              'Converted call: %s; owner: %s\n    args: %s\n    kwargs: %s\n',
              f, owner, args, kwargs)

  # Every call path below unpacks `**kwargs`; a None value would otherwise
  # raise TypeError (e.g. in the builtin and functools.partial paths).
  if kwargs is None:
    kwargs = {}

  if owner is not None:
    if not isinstance(f, str):
      raise ValueError(
          'When owner is specified, the function name must be specified as'
          ' a string: {}'.format(f))

    # Special case when the owner is a 'super' object. In that case lookups of
    # dynamic attributes won't work. See
    # inspect_utils.SuperWrapperForDynamicAttrs.
    if isinstance(owner, super):
      owner = inspect_utils.SuperWrapperForDynamicAttrs(owner)

    f = getattr(owner, f)

  if inspect_utils.isbuiltin(f):
    return py_builtins.overload_of(f)(*args, **kwargs)

  # TODO(b/122265385): Remove this bypass.
  if (_is_known_loaded_type(f, 'wrapt', 'FunctionWrapper') or
      _is_known_loaded_type(f, 'wrapt', 'BoundFunctionWrapper')):
    logging.warn(
        'Entity {} appears to be decorated by wrapt, which is not yet supported'
        ' by AutoGraph. The function will be called without transformation.'
        ' You may however apply AutoGraph before the decorator.'.format(f))
    logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', f)
    return _call_unconverted(f, args, kwargs)

  # Constructors are permanently whitelisted.
  # TODO(mdan): Toggle as experimental feature instead.
  # TODO(b/124016764): Remove this limitation.
  if tf_inspect.isclass(f):
    logging.log(2, 'Permanently whitelisted: %s: constructor', f)
    return _call_unconverted(f, args, kwargs)

  # Other built-in modules are permanently whitelisted.
  # TODO(mdan): Figure out how to do this consistently for all stdlib modules.
  # Note: TF linter disallows importing inspect.
  if any(f in m.__dict__.values()
         for m in (collections, pdb, copy, tf_inspect._inspect)):  # pylint:disable=protected-access
    logging.log(2, 'Permanently whitelisted: %s: part of builtin module', f)
    return _call_unconverted(f, args, kwargs)

  if not options.force_conversion and conversion.is_whitelisted_for_graph(f):
    return _call_unconverted(f, args, kwargs)

  # internal_convert_user_code is for example turned off when issuing a dynamic
  # call conversion from generated code while in nonrecursive mode. In that
  # case we evidently don't want to recurse, but we still have to convert
  # things like builtins.
  if not options.internal_convert_user_code:
    return _call_unconverted(f, args, kwargs)

  # Bound before the try so the error handler below can always log it, even
  # when the failure happens before a branch assigns the real target.
  target_entity = f

  # TODO(mdan): Move this entire block inside to_graph.
  try:  # Begin of transformation error guards

    # Unwrap functools.partial objects
    # TODO(mdan): Consider sharing unwrapping logic with tf_inspect.
    while isinstance(f, functools.partial):
      new_args = f.args + args
      new_kwargs = {}
      if f.keywords is not None:
        new_kwargs.update(f.keywords)
      new_kwargs.update(kwargs)
      # Commit all three at once so f, args and kwargs never disagree if one
      # of the merges above raises.
      f, args, kwargs = f.func, new_args, new_kwargs

    if tf_inspect.isfunction(f) or tf_inspect.ismethod(f):
      # Regular functions
      target_entity = f
      arg_map_target = f
      f_self = inspect_utils.getmethodself(f)

      # TODO(b/119246461): This may be more elegantly handled using __get__?
      if f_self is not None:
        # If this is a method call, it may or may not include self.
        #
        # Example when self is included:
        #   converted_call(to_graph(foo.bar), foo)
        #
        # Example when self is not included:
        #   super(...).foo(args)
        #
        if owner is not None and (not args or args[0] is not owner):
          effective_args = (owner,) + args
        else:
          # When the owner is not specified, use the bound self returned by
          # inspect_utils.getmethodself.
          # TODO(b/119246461): Make sure an owner is always specified.
          if not args or args[0] is not f_self:
            effective_args = (f_self,) + args
          else:
            effective_args = (f_self,) + args[1:]
        partial_types = (f_self,)
      else:
        effective_args = args
        partial_types = ()

    elif tf_inspect.isclass(f):
      # Constructors
      # Classes passed directly are returned unconverted above, so this branch
      # is only reached when a functools.partial wraps a class.
      # TODO(mdan): Consider removing unless there is a compelling use case.
      target_entity = f
      arg_map_target = f.__init__
      effective_args = args
      partial_types = ()

    elif hasattr(f, '__call__') and hasattr(f, '__class__'):
      # Callable objects
      target_entity = f.__call__
      arg_map_target = f.__call__
      effective_args = (f,) + args
      partial_types = (f.__class__,)

    else:
      # Must be assigned before raising: the error handler below logs it.
      target_entity = f
      raise NotImplementedError('unknown callable type "%s"' % type(f))

    arg_values = tf_inspect.getcallargs(arg_map_target, *args, **kwargs)
    arg_types = {}
    for name, arg in arg_values.items():
      arg_class = arg.__class__
      arg_types[name] = (arg_class.__name__, arg_class)

    # When called from within a decorator, this is the only indication that
    # the function is a method - it appears that the decorator is applied
    # before the method is bound.
    if not partial_types:
      if 'self' in arg_values:
        if tf_inspect.isclass(arg_values['self'].__class__):
          partial_types = (arg_values['self'].__class__,)
      elif 'cls' in arg_values:
        if tf_inspect.isclass(arg_values['cls']):
          partial_types = (arg_values['cls'],)

    logging.log(3, 'Partial types in conversion of %s: %s', target_entity,
                partial_types)

    converted_f = to_graph(
        target_entity,
        recursive=options.recursive,
        arg_values=arg_values,
        arg_types=arg_types,
        experimental_optional_features=options.optional_features,
        experimental_strip_decorators=options.strip_decorators,
        experimental_verbose=options.verbose,
        experimental_partial_types=partial_types)

    if logging.has_verbosity(2):
      logging.log(2, 'Defaults of %s : %s', converted_f,
                  converted_f.__defaults__)
      callargs = tf_inspect.getcallargs(converted_f, *effective_args, **kwargs)
      formatted_callargs = '\n'.join(
          '    {}: {}'.format(k, v) for k, v in callargs.items())
      logging.log(2, 'Calling %s with\n%s\n', converted_f, formatted_callargs)

  # TODO(mdan): Reduce this list.
  except (errors.AutoGraphError, AssertionError, AttributeError, IndexError,
          KeyError, NameError, NotImplementedError, SyntaxError, TypeError,
          ValueError, IOError) as e:
    logging.log(1, 'Error transforming entity %s', target_entity, exc_info=True)
    logging.warn(
        'Entity %s could not be transformed and will be staged without change.'
        ' Error details can be found in the logs when running with the env'
        ' variable AUTOGRAPH_VERBOSITY >= 1. Please report this to the'
        ' AutoGraph team. Cause: %s', target_entity, e)

    return _call_unconverted(f, args, kwargs)

  result = converted_f(*effective_args, **kwargs)

  # The converted function's closure is simply inserted into the function's
  # module __dict__. Since modules are permanently cached, that results in
  # leaking the entire closure.
  # Normally, it's not safe to delete the module because that may release said
  # closure as well. However, in the case of converted_call we are certain the
  # function will not be executed again, so the closure should no longer be
  # needed so long as the function doesn't return any executable code.
  # TODO(mdan): Attach the closure properly, using cells.
  if all(map(_is_not_callable, nest.flatten(result))):
    # pop() rather than del: a module entry that is already gone must not
    # discard a result that has already been computed.
    sys.modules.pop(converted_f.__module__, None)

  return result
Exemplo n.º 17
0
def create_source_map(nodes, code, filename, indices_in_code):
    """Creates a source map between an annotated AST and the code it compiles to.

    Args:
      nodes: Iterable[ast.AST, ...]
      code: Text
      filename: Optional[Text]
      indices_in_code: Union[int, Iterable[int, ...]], the positions at which
          nodes appear in code. The parser always returns a module when parsing
          code. This argument indicates the position in that module's body at
          which the corresponding node should appear. A single int is accepted
          as shorthand for a one-element iterable.

    Returns:
      Dict[LineLocation, OriginInfo], mapping locations in code to locations
      indicated by origin annotations in node.

    Raises:
      ValueError: if one is raised while walking the original and reparsed
        ASTs side by side (logged as the AST lacking integrity). At verbosity
        >= 3 a diff of the two trees is logged before re-raising.
    """
    # The documented contract allows a bare int; iterating an int would
    # raise TypeError, so wrap it.
    if isinstance(indices_in_code, int):
        indices_in_code = (indices_in_code,)

    reparsed_module = parser.parse_str(code)
    reparsed_nodes = [reparsed_module.body[i] for i in indices_in_code]

    resolve(reparsed_nodes, code)
    result = {}

    try:
        for before, after in ast_util.parallel_walk(nodes, reparsed_nodes):
            # Note: generated code might not be mapped back to its origin.
            # TODO(mdan): Generated code should always be mapped to something.
            origin_info = anno.getanno(before, anno.Basic.ORIGIN, default=None)
            final_info = anno.getanno(after, anno.Basic.ORIGIN, default=None)
            if origin_info is None or final_info is None:
                continue

            line_loc = LineLocation(filename, final_info.loc.lineno)

            existing_origin = result.get(line_loc)
            if existing_origin is not None:
                # Overlaps may exist because of child nodes, but almost never to
                # different line locations. An exception is decorated functions,
                # where both lines are mapped to the same line in the AST.

                # Line overlaps: keep bottom node.
                # NOTE(review): if `loc.line_loc` is built from (filename,
                # lineno), equal line_locs imply equal linenos, so the inner
                # check always keeps the existing entry; confirm this is the
                # intended tie-break.
                if existing_origin.loc.line_loc == origin_info.loc.line_loc:
                    if existing_origin.loc.lineno >= origin_info.loc.lineno:
                        continue

                # In case of overlaps, keep the leftmost node.
                if existing_origin.loc.col_offset <= origin_info.loc.col_offset:
                    continue

            result[line_loc] = origin_info
    except ValueError:
        if logging.has_verbosity(3):
            for n, rn in zip(nodes, reparsed_nodes):
                nodes_str = pretty_printer.fmt(n, color=False, noanno=True)
                reparsed_nodes_str = pretty_printer.fmt(rn,
                                                        color=False,
                                                        noanno=True)
                diff = difflib.context_diff(nodes_str.split('\n'),
                                            reparsed_nodes_str.split('\n'),
                                            fromfile='Original nodes',
                                            tofile='Reparsed nodes',
                                            n=7)
                diff = '\n'.join(diff)
                logging.log(3, 'AST seems to lack integrity. Diff:\n%s', diff)
        raise

    return result