Beispiel #1
0
  def __call__(self, func):
    """Applies the export decorator to `func`.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The input function with _tf_api_names attribute set.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names
        and kwarg `allow_multiple_exports` not set.
    """
    names_attr = API_ATTRS[self._api_name].names
    names_attr_v1 = API_ATTRS_V1[self._api_name].names
    # Strip the API-name attributes from every symbol this export overrides.
    for overridden in self._overrides:
      undecorated = tf_decorator.unwrap(overridden)[1]
      delattr(undecorated, names_attr)
      delattr(undecorated, names_attr_v1)

    # Attach the V2 and V1 name lists to the undecorated target.
    undecorated = tf_decorator.unwrap(func)[1]
    self.set_attr(undecorated, names_attr, self._names)
    self.set_attr(undecorated, names_attr_v1, self._names_v1)
    return func
Beispiel #2
0
  def __call__(self, func):
    """Applies the export decorator to `func`.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The input function with _tf_api_names attribute set.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names
        and kwarg `allow_multiple_exports` not set.
    """
    attr_name = API_ATTRS[self._api_name].names

    # First remove the attribute from every overridden symbol.
    for overridden in self._overrides:
      undecorated = tf_decorator.unwrap(overridden)[1]
      delattr(undecorated, attr_name)

    undecorated = tf_decorator.unwrap(func)[1]

    # Look in __dict__ rather than using hasattr so that an attribute
    # inherited from a base class does not count as "already exposed";
    # subclasses must carry their own _tf_api_names.
    if attr_name in undecorated.__dict__:
      raise SymbolAlreadyExposedError(
          'Symbol %s is already exposed as %s.' %
          (undecorated.__name__, getattr(
              undecorated, attr_name)))  # pylint: disable=protected-access
    setattr(undecorated, attr_name, self._names)
    return func
Beispiel #3
0
  def __call__(self, func):
    """Applies the export decorator to `func`.

    Args:
      func: decorated symbol (function or class).

    Returns:
      The input function with _tf_api_names attribute set.

    Raises:
      SymbolAlreadyExposedError: Raised when a symbol already has API names.
    """
    # Drop the export attribute from symbols that this export overrides.
    for overridden in self._overrides:
      target = tf_decorator.unwrap(overridden)[1]
      del target._tf_api_names  # pylint: disable=protected-access

    target = tf_decorator.unwrap(func)[1]

    # __dict__ lookup (not hasattr) so that subclasses must define their
    # own _tf_api_names instead of merely inheriting the parent's.
    if '_tf_api_names' in target.__dict__:
      # pylint: disable=protected-access
      raise SymbolAlreadyExposedError(
          'Symbol %s is already exposed as %s.' %
          (target.__name__, target._tf_api_names))
      # pylint: enable=protected-access

    # Record the export names on the undecorated symbol.
    target._tf_api_names = self._names  # pylint: disable=protected-access
    return func
Beispiel #4
0
def _op_is_in_tf_version(op, version):
  """Reports whether `op` is exported in the given TF major version.

  Args:
    op: The op (possibly decorated) to look up.
    version: TF major API version, must be 1 or 2.

  Returns:
    Truthy if the op has exported names for that version (for version 1,
    also truthy when the op delegates to its V2 counterpart).

  Raises:
    ValueError: If `version` is neither 1 nor 2.
  """
  if version not in (1, 2):
    raise ValueError('Expected version 1 or 2.')
  unwrapped = tf_decorator.unwrap(op)[1]
  if version == 1:
    return (tf_export.get_v1_names(unwrapped) or
            op in _V1_OPS_THAT_DELEGATE_TO_V2_OPS)
  return tf_export.get_v2_names(unwrapped)
 def testUnwrapReturnsDecoratorListFromOutermostToInnermost(self):
   """Unwrap reports decorators ordered from outermost to innermost."""
   decs = tf_decorator.unwrap(test_decorated_function)[0]
   self.assertEqual('decorator 1', decs[0].decorator_name)
   self.assertEqual('test_decorator_increment_first_int_arg',
                    decs[1].decorator_name)
   self.assertEqual('decorator 3', decs[2].decorator_name)
   self.assertEqual('decorator 3 documentation', decs[2].decorator_doc)
 def testUnwrapBoundMethods(self):
   """Unwrap resolves the underlying target of a decorated bound method."""
   instance = TestDecoratedClass()
   self.assertEqual([2, 2, 3], instance.return_params(1, 2, 3))
   decs, target = tf_decorator.unwrap(instance.return_params)
   self.assertEqual('test_decorator_increment_first_int_arg',
                    decs[0].decorator_name)
   # The unwrapped target is the plain function: self must be passed.
   self.assertEqual([1, 2, 3], target(instance, 1, 2, 3))
 def visit(unused_path, unused_parent, children):
   """Visitor that collects TF 2.0 names."""
   for child in children:
     # child is a (name, object) pair; unwrap the object.
     unwrapped = tf_decorator.unwrap(child[1])[1]
     v2_names.update(tf_export.get_v2_names(unwrapped))
  def testReorderFileNeedsUpdate(self):
    """Checks the generated reorders file is in sync with the change spec."""
    reordered_function_names = (
        tf_upgrade_v2.TFAPIChangeSpec().reordered_function_names)
    function_reorders = (
        tf_upgrade_v2.TFAPIChangeSpec().function_reorders)

    added_names_message = """Some function names in
self.reordered_function_names are not in reorders_v2.py.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
    removed_names_message = """%s in self.reorders_v2 does not match
any name in self.reordered_function_names.
Please run the following commands to update reorders_v2.py:
bazel build tensorflow/tools/compatibility/update:generate_v2_reorders_map
bazel-bin/tensorflow/tools/compatibility/update/generate_v2_reorders_map
"""
    # Every reordered name must already appear in the generated map.
    self.assertTrue(
        reordered_function_names.issubset(function_reorders),
        added_names_message)
    # function_reorders should contain reordered_function_names
    # and their TensorFlow V1 aliases.
    for name in function_reorders:
      # get other names for this function
      attr = get_symbol_for_name(tf.compat.v1, name)
      _, attr = tf_decorator.unwrap(attr)
      v1_names = tf_export.get_v1_names(attr)
      self.assertTrue(v1_names)
      v1_names = ["tf.%s" % n for n in v1_names]
      # check if any other name is in
      self.assertTrue(
          any(n in reordered_function_names for n in v1_names),
          removed_names_message % name)
Beispiel #9
0
def get_canonical_name_for_symbol(symbol, api_name=TENSORFLOW_API_NAME):
  """Get canonical name for the API symbol.

  Canonical name is the first non-deprecated endpoint name.

  Args:
    symbol: API function or class.
    api_name: API name (tensorflow or estimator).

  Returns:
    Canonical name for the API symbol (for e.g. initializers.zeros) if
    canonical name could be determined. Otherwise, returns None.
  """
  if not hasattr(symbol, '__dict__'):
    return None
  undecorated = tf_decorator.unwrap(symbol)[1]
  names_attr = API_ATTRS[api_name].names
  # Only symbols that carry their own export names qualify.
  if names_attr not in undecorated.__dict__:
    return None
  # TODO(annarev): may be add a separate deprecated attribute
  # for estimator names.
  deprecated = undecorated.__dict__.get('_tf_deprecated_api_names', [])
  return get_canonical_name(getattr(undecorated, names_attr), deprecated)
Beispiel #10
0
def fn_args(fn):
  """Get argument names for function-like object.

  Args:
    fn: Function, or function-like object (e.g., result of `functools.partial`).

  Returns:
    `tuple` of string argument names.

  Raises:
    ValueError: if partial function has positionally bound arguments
  """
  fn = tf_decorator.unwrap(fn)[1]

  # Callable objects whose __call__ is a bound method: inspect __call__.
  if hasattr(fn, '__call__') and tf_inspect.ismethod(fn.__call__):
    return tuple(tf_inspect.getargspec(fn.__call__).args)

  # functools.partial and similar objects (possibly nested): recurse into
  # the wrapped callable, then drop positionally and keyword-bound args.
  if hasattr(fn, 'func') and hasattr(fn, 'keywords') and hasattr(fn, 'args'):
    wrapped_args = fn_args(fn.func)
    if not wrapped_args:
      return tuple()
    bound_keywords = set((fn.keywords or {}).keys())
    return tuple(arg for arg in wrapped_args[len(fn.args):]
                 if arg not in bound_keywords)

  # Plain function.
  return tuple(tf_inspect.getargspec(fn).args)
Beispiel #11
0
def getfullargspec(obj):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for `inspect.getfullargspec`/`getargspec`.

  Uses `inspect.getfullargspec` where available and falls back to
  `inspect.getargspec` on Python 2, widening its result to a `FullArgSpec`.

  Args:
    obj: A callable, possibly decorated.

  Returns:
    The `FullArgSpec` that describes the signature of
    the outermost decorator that changes the callable's signature. If the
    callable is not decorated, `inspect.getfullargspec()` will be called
    directly on the callable.
  """
  if six.PY2:
    def spec_fn(target):
      # Widen the legacy ArgSpec into a FullArgSpec with empty
      # keyword-only/annotation fields.
      legacy = _inspect.getargspec(target)
      return FullArgSpec(
          args=legacy.args,
          varargs=legacy.varargs,
          varkw=legacy.keywords,
          defaults=legacy.defaults,
          kwonlyargs=[],
          kwonlydefaults=None,
          annotations={})
  else:
    spec_fn = _inspect.getfullargspec

  decorators, target = tf_decorator.unwrap(obj)
  # Prefer the outermost decorator that declares its own argspec.
  for dec in decorators:
    if dec.decorator_argspec is not None:
      return dec.decorator_argspec
  return spec_fn(target)
def get_api_init_text(packages,
                      output_package,
                      api_name,
                      api_version,
                      compat_api_versions=None):
  """Get a map from destination module to __init__.py code for that module.

  Args:
    packages: Base python packages containing python with target tf_export
      decorators.
    output_package: Base output python package where generated API will be
      added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version you want to generate (1 or 2).
    compat_api_versions: Additional API versions to generate under compat/
      directory.

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: (string) text that should be in __init__.py files for
        corresponding modules.
  """
  if compat_api_versions is None:
    compat_api_versions = []
  module_code_builder = _ModuleInitCodeBuilder(output_package)
  # Traverse over everything imported above. Specifically,
  # we want to traverse over TensorFlow Python modules.

  def in_packages(m):
    # NOTE(review): this is a substring match, so any module whose dotted
    # name merely contains a package string matches — confirm intended.
    return any(package in m for package in packages)

  # Snapshot sys.modules: importing during iteration would otherwise
  # mutate the dict mid-loop.
  for module in list(sys.modules.values()):
    # Only look at tensorflow modules.
    if (not module or not hasattr(module, '__name__') or
        module.__name__ is None or not in_packages(module.__name__)):
      continue
    # Do not generate __init__.py files for contrib modules for now.
    if (('.contrib.' in module.__name__ or module.__name__.endswith('.contrib'))
        and '.lite' not in module.__name__):
      continue

    for module_contents_name in dir(module):
      if (module.__name__ + '.' + module_contents_name
          in _SYMBOLS_TO_SKIP_EXPLICITLY):
        continue
      attr = getattr(module, module_contents_name)
      _, attr = tf_decorator.unwrap(attr)

      # Register the symbol for the main API version, then for each
      # requested compat/vN API version.
      add_imports_for_symbol(
          module_code_builder, attr, module.__name__, module_contents_name,
          api_name, api_version)
      for compat_api_version in compat_api_versions:
        add_imports_for_symbol(
            module_code_builder, attr, module.__name__, module_contents_name,
            api_name, compat_api_version,
            _COMPAT_MODULE_TEMPLATE % compat_api_version)

  return module_code_builder.build()
 def visit(unused_path, unused_parent, children):
   """Visitor that collects TF 2.0 names."""
   for child in children:
     symbol = tf_decorator.unwrap(child[1])[1]
     # Some objects (e.g. builtins) have no __dict__ at all.
     if not hasattr(symbol, '__dict__'):
       continue
     for name in symbol.__dict__.get(_TENSORFLOW_API_ATTR, []):
       v2_names.add(name)
 def testUnwrapReturnsListOfUniqueTFDecorators(self):
   """Each unwrapped decorator is a distinct TFDecorator instance."""
   decs, _ = tf_decorator.unwrap(test_decorated_function)
   self.assertEqual(3, len(decs))
   for dec in decs:
     self.assertTrue(isinstance(dec, tf_decorator.TFDecorator))
   # All three must be separate objects, not one shared instance.
   self.assertIsNot(decs[0], decs[1])
   self.assertIsNot(decs[1], decs[2])
   self.assertIsNot(decs[2], decs[0])
  def testRewrapMutatesAffectedFunction(self):
    """rewrap() replaces the wrapped target in place."""

    def tripler(x):
      return x * 3

    # Before rewrap the original target doubles: ((1*2)+1)**2.
    self.assertEqual((1 * 2 + 1) ** 2, test_rewrappable_decorated(1))
    previous, _ = tf_decorator.unwrap(test_rewrappable_decorated)
    tf_decorator.rewrap(test_rewrappable_decorated, previous, tripler)
    # After rewrap the new target triples: ((1*3)+1)**2.
    self.assertEqual((1 * 3 + 1) ** 2, test_rewrappable_decorated(1))
 def visit(unused_path, unused_parent, children):
   """Visitor that collects rename strings to add to rename_line_set."""
   for child in children:
     symbol = tf_decorator.unwrap(child[1])[1]
     v1_names = tf_export.get_v1_names(symbol)
     v2_names = tf_export.get_v2_names(symbol)
     # Names present only in V1 are deprecated; map each to its canonical
     # V2 replacement.
     for name in set(v1_names) - set(v2_names):
       renames.add((name, get_canonical_name(v2_names, name)))
    def conversion_visitor(unused_path, unused_parent, children):
      """Checks converted keyword args exist on the v2 (and v1) functions.

      For every v1 endpoint name of each visited function, builds a call
      using every v1 argument name, runs the upgrader on it, and asserts
      that each argument of the converted call is accepted by the target
      v2 (and v1) function.
      """
      for child in children:
        _, attr = tf_decorator.unwrap(child[1])
        if not tf_inspect.isfunction(attr):
          continue
        names_v1 = tf_export.get_v1_names(attr)
        arg_names_v1 = get_args(attr)

        for name in names_v1:
          tf_name = "tf.%s" % name
          if tf_name in function_warnings or tf_name in function_transformers:
            continue  # These require manual change
          if tf_name in v1_name_exceptions:
            continue
          # Assert that arg names after converting to v2 are present in
          # v2 function.
          # 1. First, create an input of the form:
          #    tf.foo(arg1=val1, arg2=val2, ...)
          args = ",".join(
              ["%s=%d" % (from_name, from_index)
               for from_index, from_name in enumerate(arg_names_v1)])
          text_input = "%s(%s)" % (tf_name, args)
          # 2. Convert the input to V2.
          _, _, _, text = self._upgrade(text_input)
          new_function_name, new_args = get_func_and_args_from_str(text)
          if new_function_name == "tf.compat.v1.%s" % name:
            if tf_name in keyword_renames:
              # If we rename arguments, new function must be available in 2.0.
              # We should not be using compat.v1 in this case.
              # Fix: use fail() instead of assertFalse(<string>) — the old
              # form only failed because a non-empty string is truthy.
              self.fail(
                  "Function '%s' is not in 2.0 when converting\n%s\nto\n%s" %
                  (new_function_name, text_input, text))
            continue
          # 3. Verify V2 function and arguments.
          args_v2 = get_args(self.v2_symbols[new_function_name])
          args_v2.extend(v2_arg_exceptions)
          for new_arg in new_args:
            self.assertIn(
                new_arg, args_v2,
                "Invalid argument '%s' in 2.0 when converting\n%s\nto\n%s.\n"
                "Supported arguments: %s" % (
                    new_arg, text_input, text, str(args_v2)))
          # 4. Verify that the argument exists in v1 as well.
          if new_function_name in set(["tf.nn.ctc_loss",
                                       "tf.saved_model.save"]):
            continue
          args_v1 = get_args(self.v1_symbols[new_function_name])
          args_v1.extend(v2_arg_exceptions)
          for new_arg in new_args:
            self.assertIn(
                new_arg, args_v1,
                "Invalid argument '%s' in 1.0 when converting\n%s\nto\n%s.\n"
                "Supported arguments: %s" % (
                    new_arg, text_input, text, str(args_v1)))
 def visit(unused_path, unused_parent, children):
   """Visitor that collects rename strings to add to rename_line_set."""
   for child in children:
     symbol = tf_decorator.unwrap(child[1])[1]
     # Builtins and similar objects carry no __dict__; skip them.
     if not hasattr(symbol, '__dict__'):
       continue
     v1_names = symbol.__dict__.get(_TENSORFLOW_API_ATTR_V1, [])
     v2_names = symbol.__dict__.get(_TENSORFLOW_API_ATTR, [])
     # V1-only names are deprecated; record each with its canonical
     # replacement.
     for name in set(v1_names) - set(v2_names):
       renames.add((name, get_canonical_name(v2_names, name)))
Beispiel #19
0
def _get_func_name(func):
  """Returns a human-readable name for a callable.

  Args:
    func: A callable, possibly wrapped in TFDecorators.

  Returns:
    The function's __name__, "<ClassName>.<method_name>" for bound methods,
    or the string form of the callable's type for other callables (e.g.
    class instances defining __call__).

  Raises:
    ValueError: If `func` is not callable.
  """
  _, func = tf_decorator.unwrap(func)
  if not callable(func):
    raise ValueError("Argument must be callable")
  if tf_inspect.isfunction(func):
    return func.__name__
  if tf_inspect.ismethod(func):
    # Fix: a bound method's __self__ is the instance, which generally has
    # no __name__ (AttributeError before); use the instance's class name,
    # matching the get_func_name convention.
    return "%s.%s" % (func.__self__.__class__.__name__, func.__name__)
  # Probably a class instance with __call__. Fix: return a string rather
  # than the type object, consistent with get_func_name.
  return str(type(func))
def serialize_keras_object(instance):
  """Serializes a Keras object to its config, its name, or None."""
  instance = tf_decorator.unwrap(instance)[1]
  if instance is None:
    return None
  # Objects exposing get_config() serialize to {class_name, config}.
  if hasattr(instance, 'get_config'):
    config = instance.get_config()
    return serialize_keras_class_and_config(instance.__class__.__name__,
                                            config)
  # Plain functions/classes serialize to their name.
  if hasattr(instance, '__name__'):
    return instance.__name__
  raise ValueError('Cannot serialize', instance)
 def visit(unused_path, unused_parent, children):
   """Visitor that collects arguments for reordered functions."""
   for child in children:
     symbol = tf_decorator.unwrap(child[1])[1]
     v1_names = ['tf.%s' % n for n in tf_export.get_v1_names(symbol)]
     # Only record functions whose v1 endpoints intersect function_names.
     if not any(n in function_names for n in v1_names):
       continue
     arg_list = tf_inspect.getargspec(symbol)[0]
     for name in v1_names:
       function_to_args[name] = arg_list
 def conversion_visitor(unused_path, unused_parent, children):
   """Asserts every upgraded v1 symbol resolves to a real v2 symbol."""
   for child in children:
     _, attr = tf_decorator.unwrap(child[1])
     api_names = tf_export.get_v1_names(attr)
     for name in api_names:
       _, _, _, text = self._upgrade("tf." + name)
       if (text and
           not text.startswith("tf.compat.v1") and
           text not in self.v2_symbols):
         # Fix: self.fail() replaces assertFalse(True, msg) — identical
         # effect, explicit intent.
         self.fail("Symbol %s generated from %s not in v2 API" % (
             text, name))
Beispiel #23
0
def getfile(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.getfile."""
  _, unwrapped = tf_decorator.unwrap(object)

  # Stack frames from .pyc-only installs can make inspect.getfile report
  # an incorrect path, so prefer the __file__ recorded in f_globals.
  if hasattr(unwrapped, 'f_globals') and '__file__' in unwrapped.f_globals:
    return unwrapped.f_globals['__file__']
  return _inspect.getfile(unwrapped)
 def visit(unused_path, unused_parent, children):
   """Visitor that collects rename strings to add to rename_line_set."""
   for child in children:
     symbol = tf_decorator.unwrap(child[1])[1]
     if not hasattr(symbol, '__dict__'):
       continue
     exported_names = symbol.__dict__.get(tensorflow_api_attr, [])
     deprecated_names = symbol.__dict__.get('_tf_deprecated_api_names', [])
     canonical = tf_export.get_canonical_name(
         exported_names, deprecated_names)
     # Map every deprecated alias to the canonical endpoint.
     for name in deprecated_names:
       rename_line_set.add('    \'tf.%s\': \'tf.%s\'' % (name, canonical))
Beispiel #25
0
def get_func_code(func):
  """Returns func_code of passed callable."""
  func = tf_decorator.unwrap(func)[1]
  if not callable(func):
    raise ValueError('Argument must be callable')
  if tf_inspect.isfunction(func) or tf_inspect.ismethod(func):
    return six.get_function_code(func)
  # Callable objects: inspect their __call__ instead.
  if hasattr(func, '__call__'):
    return six.get_function_code(func.__call__)
  raise ValueError('Unhandled callable, type=%s' % type(func))
def get_func_name(func):
  """Returns name of passed callable."""
  func = tf_decorator.unwrap(func)[1]
  if not callable(func):
    raise ValueError('Argument must be callable')
  if tf_inspect.isfunction(func):
    return func.__name__
  if tf_inspect.ismethod(func):
    owner = six.get_method_self(func).__class__.__name__
    return '%s.%s' % (owner, six.get_method_function(func).__name__)
  # Probably a class instance with __call__
  return str(type(func))
Beispiel #27
0
def getargspec(obj):
  """TFDecorator-aware replacement for `inspect.getargspec`.

  Note: `getfullargspec` is recommended as the python 2/3 compatible
  replacement for this function.

  Args:
    obj: A function, partial function, or callable object, possibly decorated.

  Returns:
    The `ArgSpec` that describes the signature of the outermost decorator that
    changes the callable's signature, or the `ArgSpec` that describes
    the object if not decorated.

  Raises:
    ValueError: When callable's signature can not be expressed with
      ArgSpec.
    TypeError: For objects of unsupported types.
  """
  # Partials need special handling: their bound args shift the spec.
  if isinstance(obj, functools.partial):
    return _get_argspec_for_partial(obj)

  decorators, target = tf_decorator.unwrap(obj)

  # Prefer the argspec declared by the outermost decorator, if any.
  spec = next((d.decorator_argspec
               for d in decorators
               if d.decorator_argspec is not None), None)
  if spec:
    return spec

  try:
    # Python3 will handle most callables here (not partial).
    return _getargspec(target)
  except TypeError:
    pass

  # For classes, fall back to the constructor: __init__ first, then __new__.
  if isinstance(target, type):
    try:
      return _getargspec(target.__init__)
    except TypeError:
      pass

    try:
      return _getargspec(target.__new__)
    except TypeError:
      pass

  # The `type(target)` ensures that if a class is received we don't return
  # the signature of its __call__ method.
  return _getargspec(type(target).__call__)
def has_deprecation_decorator(symbol):
  """Checks if given object has a deprecation decorator.

  We check if deprecation decorator is in decorators as well as
  whether symbol is a class whose __init__ method has a deprecation
  decorator.
  Args:
    symbol: Python object.

  Returns:
    True if symbol has deprecation decorator.
  """
  decorators, symbol = tf_decorator.unwrap(symbol)
  if contains_deprecation_decorator(decorators):
    return True
  # Only classes get the extra __init__ check; anything else is done.
  if tf_inspect.isfunction(symbol) or not tf_inspect.isclass(symbol):
    return False
  if not hasattr(symbol, '__init__'):
    return False
  init_decorators = tf_decorator.unwrap(symbol.__init__)[0]
  return contains_deprecation_decorator(init_decorators)
Beispiel #29
0
def getargspec(object):  # pylint: disable=redefined-builtin
  """TFDecorator-aware replacement for inspect.getargspec.

  Args:
    object: A callable, possibly decorated.

  Returns:
    The `ArgSpec` that describes the signature of the outermost decorator that
    changes the callable's signature. If the callable is not decorated,
    `inspect.getargspec()` will be called directly on the callable.
  """
  decorators, target = tf_decorator.unwrap(object)
  # Prefer an argspec explicitly declared by one of the decorators.
  for dec in decorators:
    if dec.decorator_argspec is not None:
      return dec.decorator_argspec
  return _inspect.getargspec(target)
 def conversion_visitor(unused_path, unused_parent, children):
   """Asserts every upgraded v1 symbol resolves to a real v2 symbol."""
   for child in children:
     _, attr = tf_decorator.unwrap(child[1])
     api_names = tf_export.get_v1_names(attr)
     for name in api_names:
       _, _, _, text = self._upgrade("tf." + name)
       if (text and
           not text.startswith("tf.compat.v1") and
           text not in self.v2_symbols and
           # Builds currently install old version of estimator that doesn't
           # have some 2.0 symbols.
           not text.startswith("tf.estimator")):
         # Fix: self.fail() replaces assertFalse(True, msg) — identical
         # effect, explicit intent.
         self.fail("Symbol %s generated from %s not in v2 API" % (
             text, name))
Beispiel #31
0
def get_canonical_name_for_symbol(
    symbol, api_name=TENSORFLOW_API_NAME,
    add_prefix_to_v1_names=False):
  """Get canonical name for the API symbol.

  Args:
    symbol: API function or class.
    api_name: API name (tensorflow or estimator).
    add_prefix_to_v1_names: Specifies whether a name available only in V1
      should be prefixed with compat.v1.

  Returns:
    Canonical name for the API symbol (for e.g. initializers.zeros) if
    canonical name could be determined. Otherwise, returns None.
  """
  if not hasattr(symbol, '__dict__'):
    return None
  undecorated = tf_decorator.unwrap(symbol)[1]
  v2_names_attr = API_ATTRS[api_name].names
  if v2_names_attr not in undecorated.__dict__:
    return None
  deprecated = undecorated.__dict__.get('_tf_deprecated_api_names', [])

  # Prefer a canonical V2 endpoint name.
  v2_canonical = get_canonical_name(
      getattr(undecorated, v2_names_attr), deprecated)
  if v2_canonical:
    return v2_canonical

  # No V2 canonical name: fall back to the V1 endpoint names.
  v1_names = getattr(undecorated, API_ATTRS_V1[api_name].names)
  v1_canonical = get_canonical_name(v1_names, deprecated)
  if add_prefix_to_v1_names:
    return 'compat.v1.%s' % v1_canonical
  return v1_canonical
Beispiel #32
0
 def _AddMember(member_name, member_obj, proto):
     """Add the child object to the object being constructed."""
     # NOTE(review): `parent` below comes from the enclosing scope —
     # confirm it is the container of `member_name` at the call site.
     _, member_obj = tf_decorator.unwrap(member_obj)
     # Skip explicitly excluded members and hidden deprecated attributes.
     if (_SkipMember(parent, member_name) or isinstance(
             member_obj, deprecation.HiddenTfApiAttribute)):
         return
     # Only record public members, plus __init__.
     if member_name == '__init__' or not six.ensure_str(
             member_name).startswith('_'):
         if tf_inspect.isroutine(member_obj):
             new_method = proto.member_method.add()
             new_method.name = member_name
             # If member_obj is a python builtin, there is no way to get its
             # argspec, because it is implemented on the C side. It also has no
             # func_code.
             if hasattr(member_obj, '__code__'):
                 new_method.argspec = _SanitizedArgSpec(member_obj)
         else:
             # Non-routine members are recorded with a normalized type name.
             new_member = proto.member.add()
             new_member.name = member_name
             if tf_inspect.ismodule(member_obj):
                 new_member.mtype = "<type \'module\'>"
             else:
                 new_member.mtype = _NormalizeType(str(
                     type(member_obj)))
Beispiel #33
0
def register_dispatchers():
    """Constructs & registers OpDispatchers for ragged ops."""

    op_list = (_UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
               _BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS])
    # Sanity-check: every op we dispatch for must be an exported TF symbol.
    for op in op_list:
        _, undecorated_op = tf_decorator.unwrap(op)
        if not hasattr(
                undecorated_op,
                tf_export.API_ATTRS[tf_export.TENSORFLOW_API_NAME].names):
            # Fix: the '%s' placeholder was never interpolated; include the
            # offending op so the error identifies it.
            raise AssertionError('Expected %s to be an exported symbol '
                                 '(while adding a RaggedTensor dispatcher)' %
                                 (op,))

    for op in _UNARY_ELEMENTWISE_OPS:
        UnaryRaggedElementwiseDispatcher(op).register(op)

    for op in _UNARY_LIST_ELEMENTWISE_OPS:
        UnaryRaggedElementwiseDispatcher(op, True).register(op)

    for op in _BINARY_ELEMENTWISE_OPS:
        BinaryRaggedElementwiseDispatcher(op).register(op)

    for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS:
        RaggedDispatcher(original_op, ragged_op, args).register(original_op)
Beispiel #34
0
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            autograph=False,
                            autograph_options=None,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None,
                            collections=None,
                            capture_by_value=None):
    """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph else a new one is built and returned.
    autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    autograph_options: additional knobs to control when `autograph=True`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.
    collections: a dictionary of collections this FuncGraph should start
      with. If not specified (None), the FuncGraph will read (but not write to)
      the outer graph's collections that are not whitelisted, and both
      read and write to the outer graph's collections that are whitelisted.
      The current whitelisted collections are the global variables, the
      local variables, and the trainable variables.
      Defaults to None.
    capture_by_value: An optional boolean. If True, the func graph will capture
      Variables by value instead of reference. By default inherit from outer
      graphs, and failing that will default to False.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
  """
    if op_return_value is not None:
        assert isinstance(op_return_value, ops.Tensor), op_return_value
    if func_graph is None:
        func_graph = FuncGraph(name,
                               collections=collections,
                               capture_by_value=capture_by_value)
    assert isinstance(func_graph, FuncGraph)
    if add_control_dependencies:
        control_manager = AutomaticControlDependencies()
    else:
        control_manager = ops.NullContextmanager()
    with func_graph.as_default(), control_manager as a:
        current_scope = variable_scope.get_variable_scope()
        # Fix: renamed from the original misspelled local
        # `default_use_recource`; restored in the `finally` block below.
        default_use_resource = current_scope.use_resource
        current_scope.set_use_resource(True)

        if signature is not None:
            args = signature
            kwargs = {}

        # Creates and names placeholders for all arguments.
        func_args = _get_defun_inputs_from_args(args, arg_names)
        func_kwargs = _get_defun_inputs_from_kwargs(kwargs)

        # Convert all Tensors into TensorSpecs before saving the structured inputs.
        # If storing pure concrete functions that are not called through polymorphic
        # functions, we don't have access to FunctionSpec, so we need to call the
        # TensorSpecs by their `arg_names` for later binding.
        func_graph.structured_input_signature = (
            convert_structure_to_signature(func_args, arg_names),
            convert_structure_to_signature(func_kwargs))

        flat_func_args = nest.flatten(func_args)
        flat_func_kwargs = nest.flatten(func_kwargs)
        # Temporarily set inputs to allow graph building code to inspect
        # them. Reassigned below.
        func_graph.inputs = [
            arg for arg in flat_func_args + flat_func_kwargs
            if isinstance(arg, ops.Tensor)
        ]

        # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
        # Variables to help check whether mutation happens in calling the function
        # Copy the recursive list, tuple and map structure, but not base objects
        func_args_before = nest.pack_sequence_as(func_args, flat_func_args)
        func_kwargs_before = nest.pack_sequence_as(func_kwargs,
                                                   flat_func_kwargs)

        def convert(x):
            """Converts a function output to a Tensor."""
            if x is None:
                return None
            if op_return_value is not None and isinstance(x, ops.Operation):
                # TODO(b/79881896): we currently can't capture external control deps, so
                # this won't work if x needs to be captured (i.e. if python_func returns
                # captured Operations).
                with ops.control_dependencies([x]):
                    x = array_ops.identity(op_return_value)
            elif not isinstance(x, tensor_array_ops.TensorArray):
                try:
                    x = ops.convert_to_tensor_or_composite(x)
                except (ValueError, TypeError):
                    raise TypeError(
                        "To be compatible with tf.contrib.eager.defun, Python functions "
                        "must return zero or more Tensors; in compilation of %s, found "
                        "return value of type %s, which is not a Tensor." %
                        (str(python_func), type(x)))
            if add_control_dependencies:
                x = a.mark_as_return(x)
            return x

        # Record gradient-relevant variable usage while tracing.
        this_tape = tape.push_new_tape()
        try:
            if autograph:
                from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
                _, original_func = tf_decorator.unwrap(python_func)

                def wrapper(*args, **kwargs):
                    # Note: functions annotated with @tf.function should always be
                    # converted even though they would meet autograph's whitelisting
                    # criteria.
                    # If this assumption is ever broken, converted_call will need to
                    # handle the possibility of original_func still being a shim, e.g.
                    # bound to WeakrefSelf.
                    return autograph.converted_call(
                        original_func, None,
                        autograph.ConversionOptions(
                            verbose=autograph.Verbosity.BRIEF,
                            recursive=True,
                            strip_decorators=(def_function.function, ),
                            optional_features=autograph_options,
                            force_conversion=True,
                        ), args, kwargs)

                # Wrapping around a decorator allows checks like tf_inspect.getargspec
                # to be accurate.
                converted_func = tf_decorator.make_decorator(
                    original_func, wrapper)
                tf_decorator.rewrap(python_func, original_func, converted_func)

            func_outputs = python_func(*func_args, **func_kwargs)

            # invariant: `func_outputs` contains only Tensors, IndexedSlices,
            # SparseTensors, TensorArrays and `None`s.
            func_outputs = nest.map_structure(convert, func_outputs)

            check_mutation(func_args_before, func_args)
            check_mutation(func_kwargs_before, func_kwargs)
        finally:
            tape.pop_tape(this_tape)
            current_scope.set_use_resource(default_use_resource)

        # Variables in `func_args`, `func_kwargs` should be explicit inputs
        # to the function, not captured inputs.
        tape_variables = this_tape.watched_variables()
        arg_variables = set()
        inputs = []
        for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
            if isinstance(arg, resource_variable_ops.ResourceVariable):
                # Even if an argument variable was not used in the function, we've
                # already manually captured the resource Tensor when creating argument
                # placeholders.
                resource_placeholder = func_graph.captures.pop(
                    arg.handle, None)
                if resource_placeholder is None:
                    continue
                arg_variables.add(arg)
                inputs.append(resource_placeholder)
            elif isinstance(arg, ops.Tensor):
                inputs.append(arg)
        variables = [v for v in tape_variables if v not in arg_variables]
        func_graph.inputs = inputs + list(func_graph.captures.values())

        func_graph.structured_outputs = func_outputs
        # Returning a closed-over tensor does not trigger convert_to_tensor.
        func_graph.outputs.extend(
            func_graph.capture(x)
            for x in flatten(func_graph.structured_outputs) if x is not None)

        func_graph.variables = variables

    if add_control_dependencies:
        func_graph.control_outputs.extend(control_manager.ops_which_must_run)

    # Register any other functions defined in the graph.
    with ops.init_scope():
        if context.executing_eagerly():
            for f in func_graph._functions.values():  # pylint: disable=protected-access
                # TODO(ashankar): What about the gradient registry?
                context.add_function(f._c_func.func)  # pylint: disable=protected-access

    return func_graph
Beispiel #35
0
def isroutine(object):  # pylint: disable=redefined-builtin
    """TFDecorator-aware replacement for inspect.isroutine."""
    # Unwrap any TFDecorator chain first, then test the innermost target.
    _, target = tf_decorator.unwrap(object)
    return _inspect.isroutine(target)
Beispiel #36
0
def serialize_keras_object(instance):
    """Serialize a Keras object into a JSON-compatible representation.

  Calls to `serialize_keras_object` while underneath the
  `SharedObjectSavingScope` context manager will cause any objects re-used
  across multiple layers to be saved with a special shared object ID. This
  allows the network to be re-created properly during deserialization.

  Args:
    instance: The object to serialize.

  Returns:
    A dict-like, JSON-compatible representation of the object's config.
  """
    _, instance = tf_decorator.unwrap(instance)
    if instance is None:
        return None

    # pylint: disable=protected-access
    #
    # For v1 layers, `supports_masking` alone is not conclusive: a layer may
    # also support masking by overriding `compute_mask`.
    masking = getattr(instance, 'supports_masking', False)
    if not masking:
        masking = (hasattr(instance, 'compute_mask')
                   and not is_default(instance.compute_mask))
    if masking and is_default(instance.get_config):
        warnings.warn(
            'Custom mask layers require a config and must override '
            'get_config. When loading, the custom mask layer must be '
            'passed to the custom_objects argument.',
            category=CustomMaskWarning)
    # pylint: enable=protected-access

    if hasattr(instance, 'get_config'):
        name = get_registered_name(instance.__class__)
        try:
            config = instance.get_config()
        except NotImplementedError as e:
            if _SKIP_FAILED_SERIALIZATION:
                return serialize_keras_class_and_config(
                    name, {_LAYER_UNDEFINED_CONFIG_KEY: True})
            raise e

        serialization_config = {}
        for cfg_key, cfg_value in config.items():
            # Strings are JSON-compatible already; pass them through.
            if isinstance(cfg_value, str):
                serialization_config[cfg_key] = cfg_value
                continue
            # Anything else (custom functions, classes, nested objects) is
            # serialized recursively; values we cannot serialize are kept
            # verbatim.
            try:
                serialized = serialize_keras_object(cfg_value)
                if (isinstance(serialized, dict)
                        and not isinstance(cfg_value, dict)):
                    serialized['__passive_serialization__'] = True
                serialization_config[cfg_key] = serialized
            except ValueError:
                serialization_config[cfg_key] = cfg_value

        name = get_registered_name(instance.__class__)
        return serialize_keras_class_and_config(name, serialization_config,
                                                instance)
    if hasattr(instance, '__name__'):
        return get_registered_name(instance)
    raise ValueError('Cannot serialize', instance)
Beispiel #37
0
def get_api_imports():
    """Get a map from destination module to formatted imports.

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: List of strings representing module imports
          (for e.g. 'from foo import bar') and constant
          assignments (for e.g. 'FOO = 123').
  """
    imports_by_dest = collections.defaultdict(list)
    # Walk every module imported so far; only TensorFlow Python modules
    # contribute API entries.
    for mod in sys.modules.values():
        if not mod or 'tensorflow.' not in mod.__name__:
            continue

        for contents_name in dir(mod):
            symbol = getattr(mod, contents_name)

            # The _tf_api_constants attribute lists exported constants.
            if contents_name == _API_CONSTANTS_ATTR:
                for exports, value in symbol:
                    for export in exports:
                        path = ['tf'] + export.split('.')
                        imports_by_dest['.'.join(path[:-1])].append(
                            format_import(mod.__name__, value, path[-1]))
                continue

            symbol = tf_decorator.unwrap(symbol)[1]
            # Symbols tagged with _tf_api_names get imports; we check
            # __dict__ so inherited attributes don't count.
            if (hasattr(symbol, '__dict__')
                    and _API_NAMES_ATTR in symbol.__dict__):
                # The same op might be accessible from multiple modules; only
                # the defining module is considered.
                if symbol.__module__ != mod.__name__:
                    continue

                for export in symbol._tf_api_names:  # pylint: disable=protected-access
                    path = ['tf'] + export.split('.')
                    imports_by_dest['.'.join(path[:-1])].append(
                        format_import(mod.__name__, contents_name, path[-1]))

    # Make every submodule importable from its parent, e.g. importing
    # 'tf.foo.bar.Value' also imports 'bar' inside 'tf.foo'.
    for dest in set(imports_by_dest.keys()):
        parts = dest.split('.')
        for depth in range(1, len(parts)):
            parent = '.'.join(parts[:depth])
            sub_import = format_import('', parts[depth], parts[depth])
            if sub_import not in imports_by_dest[parent]:
                imports_by_dest[parent].append(sub_import)

    return imports_by_dest
Beispiel #38
0
    def SmartSet(self, obj, attr_name, new_attr):
        """Replace obj.attr_name with new_attr.

    This method is smart and works at the module, class, and instance level
    while preserving proper inheritance. It will not stub out C types however
    unless that has been explicitly allowed by the type.

    This method supports the case where attr_name is a staticmethod or a
    classmethod of obj.

    Notes:
      - If obj is an instance, then it is its class that will actually be
        stubbed. Note that the method Set() does not do that: if obj is
        an instance, it (and not its class) will be stubbed.
      - The stubbing is using the builtin getattr and setattr. So, the __get__
        and __set__ will be called when stubbing (TODO: A better idea would
        probably be to manipulate obj.__dict__ instead of getattr() and
        setattr()).

    Args:
      obj: The object whose attributes we want to modify.
      attr_name: The name of the attribute to modify.
      new_attr: The new value for the attribute.

    Raises:
      AttributeError: If the attribute cannot be found.
    """
        # Stub the undecorated target so TFDecorator wrappers are bypassed.
        _, obj = tf_decorator.unwrap(obj)
        if (tf_inspect.ismodule(obj) or
            (not tf_inspect.isclass(obj) and attr_name in obj.__dict__)):
            # Modules, and non-class objects that own the attribute directly,
            # are stubbed in place.
            orig_obj = obj
            orig_attr = getattr(obj, attr_name)
        else:
            # Otherwise walk the class MRO (reversed: base-most first) to find
            # a class owning the attribute.
            if not tf_inspect.isclass(obj):
                mro = list(tf_inspect.getmro(obj.__class__))
            else:
                mro = list(tf_inspect.getmro(obj))

            mro.reverse()

            orig_attr = None
            found_attr = False

            for cls in mro:
                try:
                    orig_obj = cls
                    # NOTE(review): this resolves the attribute on `obj`, not
                    # `cls`, so each iteration fetches the same value and
                    # `orig_obj` ends up as the last (most derived) class in
                    # the reversed MRO; presumably intentional (matches
                    # mox-style stubout) — confirm before changing.
                    orig_attr = getattr(obj, attr_name)
                    found_attr = True
                except AttributeError:
                    continue

            if not found_attr:
                raise AttributeError('Attribute not found.')

        # Calling getattr() on a staticmethod transforms it to a 'normal' function.
        # We need to ensure that we put it back as a staticmethod.
        old_attribute = obj.__dict__.get(attr_name)
        if old_attribute is not None and isinstance(old_attribute,
                                                    staticmethod):
            orig_attr = staticmethod(orig_attr)

        # Record (owner, name, original value) so the stub can be undone later.
        self.stubs.append((orig_obj, attr_name, orig_attr))
        setattr(orig_obj, attr_name, new_attr)
Beispiel #39
0
def _should_skip_symbol(attr, module_name, contents_name, names):
  """Returns True for symbols deliberately excluded from the generated API.

  These exclusions reproduce the author's manual overrides: ops that are
  re-exported from another module (mostly via tensor_shape / nn_ops), where
  emitting a second import would conflict with the override.

  Args:
    attr: the (unwrapped) exported symbol.
    module_name: `__name__` of the module currently being scanned.
    contents_name: attribute name of the symbol inside that module.
    names: the export path split on '.', e.g. ['nn', 'conv2d'].

  Returns:
    True if this (symbol, export) pair must not be added to the API.
  """
  # pylint: disable=protected-access
  attr_str = str(attr)
  name_str = str(contents_name)
  leaf = names[-1]
  # reshape is always overridden elsewhere; never emit it.
  if "reshape" in attr_str or "reshape" in name_str:
    return True
  # Several modules expose a conv2d; only drop the canonical nn.conv2d when it
  # does not come from tensor_shape (where the override lives).
  if "conv2d" in attr_str or "conv2d" in name_str:
    if (str(attr._tf_api_names[0]) == "nn.conv2d"
        and "tensor_shape" not in module_name):
      return True
  if ("nn.relu" in str(attr._tf_api_names)
      and "tensor_shape" not in module_name):
    return True
  # These gen_math_ops symbols are re-exported from tensor_shape instead.
  for op_name in ("equal", "add", "greater", "lgamma", "maximum", "minimum"):
    if (op_name in name_str and leaf == op_name
        and "gen_math_ops" in module_name):
      return True
  # Top-level `log` only (len(names) == 1), also taken over by tensor_shape.
  if ("log" in name_str and leaf == "log" and len(names) == 1
      and "gen_math_ops" in module_name):
    return True
  # `exp` is skipped from every module (override lives in tensor_shape).
  if "exp" in name_str and leaf == "exp":
    return True
  if "tile" in name_str and leaf == "tile" and "gen_array_ops" in module_name:
    return True
  # lrn / local_response_normalization comes from nn_ops instead of gen_nn_ops.
  if "lrn" in name_str and "gen_nn_ops" in module_name:
    return True
  return False
  # pylint: enable=protected-access


def get_api_init_text():
  """Get a map from destination module to __init__.py code for that module.

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: (string) text that should be in __init__.py files for
        corresponding modules.
  """
  module_code_builder = _ModuleInitCodeBuilder()

  # Traverse over everything imported above. Specifically,
  # we want to traverse over TensorFlow Python modules.
  for module in sys.modules.values():
    # Only look at tensorflow modules.
    if (not module or not hasattr(module, "__name__") or
        'tensorflow.' not in module.__name__):
      continue
    # Do not generate __init__.py files for contrib modules for now.
    if '.contrib.' in module.__name__ or module.__name__.endswith('.contrib'):
      continue

    for module_contents_name in dir(module):
      attr = getattr(module, module_contents_name)
      # If attr is _tf_api_constants attribute, then add the constants.
      if module_contents_name == _API_CONSTANTS_ATTR:
        for exports, value in attr:
          for export in exports:
            names = export.split('.')
            dest_module = '.'.join(names[:-1])
            module_code_builder.add_import(
                -1, dest_module, module.__name__, value, names[-1])
        continue

      _, attr = tf_decorator.unwrap(attr)
      # If attr is a symbol with _tf_api_names attribute, then
      # add import for it — unless it is on the manual exclusion list.
      if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
        for export in attr._tf_api_names:  # pylint: disable=protected-access
          names = export.split('.')
          dest_module = '.'.join(names[:-1])
          if _should_skip_symbol(attr, module.__name__, module_contents_name,
                                 names):
            continue
          module_code_builder.add_import(
              id(attr), dest_module, module.__name__, module_contents_name,
              names[-1])

  # Import all required modules in their parent modules.
  # For e.g. if we import 'foo.bar.Value'. Then, we also
  # import 'bar' in 'foo'.
  imported_modules = set(module_code_builder.module_imports.keys())
  for module in imported_modules:
    if not module:
      continue
    module_split = module.split('.')
    parent_module = ''  # we import submodules in their parent_module

    for submodule_index in range(len(module_split)):
      import_from = _OUTPUT_MODULE
      if submodule_index > 0:
        parent_module += ('.' + module_split[submodule_index-1] if parent_module
                          else module_split[submodule_index-1])
        import_from += '.' + parent_module
      module_code_builder.add_import(
          -1, parent_module, import_from,
          module_split[submodule_index], module_split[submodule_index])

  return module_code_builder.build()
Beispiel #40
0
 def symbol_collector(unused_path, unused_parent, children):
   """Records each v2-exported child in cls.v2_symbols under its tf name."""
   for child in children:
     symbol = tf_decorator.unwrap(child[1])[1]
     for v2_name in get_v2_names(symbol):
       cls.v2_symbols["tf." + v2_name] = symbol
Beispiel #41
0
def tf_convert(f, ctx, convert_by_default=True, user_requested=False):
    """Decorator that applies AutoGraph to a function.

  Use in internal APIs.

  This API is suitable for high order functions internal to the TensorFlow API,
  and more generally any function to which AutoGraph is not applied.

  Guidance: `convert` was a decorator meant for use directly by developers, but
  most of today's uses go through `tf.function`. `tf_convert` is to be called
  from high order functions internal to TF. By default, all the internal
  TensorFlow functions are skipped when AutoGraph processes the code. This may
  lead to user-supplied functions to be incorrectly skipped as well.
  `tf_convert` helps avoid that. See the following example for more details.

  ```
  =====tf_internal_module.py=====

  def unconverted(input_fn):
    return input_fn()

  def converted(input_fn):
    return tf.__internal__.autograph.tf_convert(
       input_fn, ctx=tf.__internal__.autograph.control_status_ctx())()

  ======user_module.py======

  @tf.function
  def foo(input_fn)
    return unconverted(input_fn)

  @tf.function
  def bar(input_fn)
    return converted(input_fn)

  @tf.function(autograph=False)
  def baz(input_fn)
    return converted(input_fn)
  ```

  The `foo` method above will execute the `input_fn` without autograph
  conversion, while the `bar` method will run an autographed `input_fn`. The
  `baz` method will run an unconverted `input_fn`, since `tf_convert` respects
  the control status context.

  Note that both methods in `tf_internal_module` are skipped by autograph when
  tracing the `tf.function`. The configuration of whether a module/package
  should be skipped by autograph is controlled in
  tensorflow/python/autograph/core/config.py.

  Args:
    f: Callable.
    ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used.
    convert_by_default: bool, whether to use AutoGraph when the context doesn't
      specify.
    user_requested: bool, whether to ignore the conversion allowlist. See
      ConversionOptions.user_requested.

  Returns:
    Either `f` or the converted version of `f`.
  """

    # Already-converted artifacts are returned untouched to avoid double
    # conversion.
    if is_autograph_artifact(f):
        return f
    f_wrapper = f
    decorators, f = tf_decorator.unwrap(f)

    # TODO(mdan): Grab features from context.
    # Note: we pass the original context through to convert to properly handle the
    # following scenario, which can be used inside TF implementations:
    #
    #   ctx = ag_ctx.control_status_ctx()
    #   @function(autograph=False)  # Low-level graph code
    #   def inner_fn():
    #     # The context is disabled here, but should be enabled in user_fn
    #     tf_convert(user_fn, ctx=ctx)
    if ctx.status == ag_ctx.Status.ENABLED:
        wrapper_factory = convert(recursive=True,
                                  user_requested=user_requested,
                                  conversion_ctx=ctx)
    elif ctx.status == ag_ctx.Status.DISABLED:
        wrapper_factory = do_not_convert
    elif ctx.status == ag_ctx.Status.UNSPECIFIED:
        if convert_by_default:
            wrapper_factory = convert(recursive=True,
                                      user_requested=user_requested,
                                      conversion_ctx=ctx)
        else:
            wrapper_factory = call_with_unspecified_conversion_status
    else:
        assert False, 'This switch contains all possible cases!'
    wrapper = wrapper_factory(f)

    # If `f` was decorated, re-wrap so the decorator chain is preserved around
    # the converted target.
    if decorators:
        wrapper = tf_decorator.rewrap(f_wrapper, f, wrapper)

    return autograph_artifact(wrapper)
Beispiel #42
0
 def testUnwrapContextManager(self):
     """unwrap() peels exactly one contextmanager TFDecorator layer."""
     decorators, target = tf_decorator.unwrap(test_params_and_defaults)
     self.assertEqual(len(decorators), 1)
     self.assertIsInstance(decorators[0], tf_decorator.TFDecorator)
     self.assertEqual(decorators[0].decorator_name, 'contextmanager')
     self.assertNotIsInstance(target, tf_decorator.TFDecorator)
Beispiel #43
0
 def symbol_collector_v1(unused_path, unused_parent, children):
   """Records each v1-exported child in cls.v1_symbols under its tf name."""
   for child in children:
     symbol = tf_decorator.unwrap(child[1])[1]
     for v1_name in tf_export.get_v1_names(symbol):
       cls.v1_symbols["tf." + v1_name] = symbol
Beispiel #44
0
 def testUnwrapReturnsFinalFunctionAsTarget(self):
     """unwrap() returns the innermost (undecorated) function as target."""
     # Decorated call applies both layers: (x + 1) * 2.
     self.assertEqual(test_decorated_function(4), (4 + 1) * 2)
     target = tf_decorator.unwrap(test_decorated_function)[1]
     self.assertTrue(tf_inspect.isfunction(target))
     # The bare target applies only the innermost layer: x * 2.
     self.assertEqual(target(4), 4 * 2)
def get_api_init_text(packages,
                      output_package,
                      api_name,
                      api_version,
                      compat_api_versions=None):
    """Get a map from destination module to __init__.py code for that module.

  Args:
    packages: Base python packages containing python with target tf_export
      decorators.
    output_package: Base output python package where generated API will be
      added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version you want to generate (1 or 2).
    compat_api_versions: Additional API versions to generate under compat/
      directory.

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: (string) text that should be in __init__.py files for
        corresponding modules.
  """
    if compat_api_versions is None:
        compat_api_versions = []
    builder = _ModuleInitCodeBuilder(output_package)

    # Walk every module imported so far; only modules inside the requested
    # base packages contribute to the generated API.
    for mod in list(sys.modules.values()):
        if (not mod or not hasattr(mod, '__name__')
                or mod.__name__ is None
                or not any(pkg in mod.__name__ for pkg in packages)):
            continue
        # Do not generate __init__.py files for contrib modules for now,
        # except for lite.
        in_contrib = ('.contrib.' in mod.__name__
                      or mod.__name__.endswith('.contrib'))
        if in_contrib and '.lite' not in mod.__name__:
            continue

        for symbol_name in dir(mod):
            if (mod.__name__ + '.' + symbol_name
                    in _SYMBOLS_TO_SKIP_EXPLICITLY):
                continue
            symbol = tf_decorator.unwrap(getattr(mod, symbol_name))[1]

            # Emit the symbol for the primary API version, then once per
            # requested compat version under the compat/ template.
            add_imports_for_symbol(builder, symbol, mod.__name__, symbol_name,
                                   api_name, api_version)
            for compat_version in compat_api_versions:
                add_imports_for_symbol(
                    builder, symbol, mod.__name__, symbol_name, api_name,
                    compat_version, _COMPAT_MODULE_TEMPLATE % compat_version)

    return builder.build()
Beispiel #46
0
def get_api_init_text(package, api_name):
  """Get a map from destination module to __init__.py code for that module.

  Args:
    package: Base python package containing python with target tf_export
      decorators.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: (string) text that should be in __init__.py files for
        corresponding modules.
  """
  builder = _ModuleInitCodeBuilder()
  constants_attr = API_ATTRS[api_name].constants
  names_attr = API_ATTRS[api_name].names

  # Walk everything imported so far; only modules under `package` matter.
  for module in list(sys.modules.values()):
    module_name = getattr(module, '__name__', None) if module else None
    if not module_name or package not in module_name:
      continue
    # Contrib modules do not get generated __init__.py files for now.
    if '.contrib.' in module_name or module_name.endswith('.contrib'):
      continue

    for symbol_name in dir(module):
      if module_name + '.' + symbol_name in _SYMBOLS_TO_SKIP_EXPLICITLY:
        continue
      attr = getattr(module, symbol_name)

      # A constants attribute enumerates (export names, value) pairs to add.
      if symbol_name == constants_attr:
        for exports, value in attr:
          for export in exports:
            dest_module, _, dest_name = export.rpartition('.')
            builder.add_import(-1, dest_module, module_name, value, dest_name)
        continue

      _, attr = tf_decorator.unwrap(attr)
      # Exported symbols carry the names attribute directly in their own
      # __dict__ (checked there so it is not picked up via inheritance).
      if hasattr(attr, '__dict__') and names_attr in attr.__dict__:
        for export in getattr(attr, names_attr):  # pylint: disable=protected-access
          dest_module, _, dest_name = export.rpartition('.')
          builder.add_import(
              id(attr), dest_module, module_name, symbol_name, dest_name)

  # Import all required modules in their parent modules.
  # For e.g. if we import 'foo.bar.Value'. Then, we also
  # import 'bar' in 'foo'.
  import_from = '.'
  for module in set(builder.module_imports.keys()):
    if not module:
      continue
    parts = module.split('.')
    parent = ''  # we import submodules in their parent_module
    for index, part in enumerate(parts):
      if index > 0:
        parent = parent + '.' + parts[index - 1] if parent else parts[index - 1]
      builder.add_import(-1, parent, import_from, part, part)

  return builder.build()
def get_api_init_text(packages,
                      output_package,
                      api_name,
                      api_version,
                      compat_api_versions=None,
                      lazy_loading=_LAZY_LOADING,
                      use_relative_imports=False):
    """Get a map from destination module to __init__.py code for that module.

  Args:
    packages: Base python packages containing python with target tf_export
      decorators.
    output_package: Base output python package where generated API will be
      added.
    api_name: API you want to generate (e.g. `tensorflow` or `estimator`).
    api_version: API version you want to generate (1 or 2).
    compat_api_versions: Additional API versions to generate under compat/
      directory.
    lazy_loading: Boolean flag. If True, a lazy loading `__init__.py` file is
      produced and if `False`, static imports are used.
    use_relative_imports: True if we should use relative imports when
      importing submodules.

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: (string) text that should be in __init__.py files for
        corresponding modules.
  """
    # None (not []) is the default so the mutable list is never shared.
    if compat_api_versions is None:
        compat_api_versions = []
    module_code_builder = _ModuleInitCodeBuilder(output_package, api_version,
                                                 lazy_loading,
                                                 use_relative_imports)

    # Traverse over everything imported above. Specifically,
    # we want to traverse over TensorFlow Python modules.

    def in_packages(m):
        # True if any base package name occurs in the dotted module path `m`.
        return any(package in m for package in packages)

    for module in list(sys.modules.values()):
        # Only look at tensorflow modules.
        if (not module or not hasattr(module, '__name__')
                or module.__name__ is None
                or not in_packages(module.__name__)):
            continue
        # Do not generate __init__.py files for contrib modules for now.
        # (.lite is exempted from the contrib exclusion.)
        if (('.contrib.' in module.__name__
             or module.__name__.endswith('.contrib'))
                and '.lite' not in module.__name__):
            continue

        for module_contents_name in dir(module):
            # Skip symbols that are explicitly excluded from the API.
            if (module.__name__ + '.' + module_contents_name
                    in _SYMBOLS_TO_SKIP_EXPLICITLY):
                continue
            attr = getattr(module, module_contents_name)
            # Unwrap TFDecorator layers so export metadata on the underlying
            # target is visible.
            _, attr = tf_decorator.unwrap(attr)

            # Register the symbol for the primary API version, then once more
            # for each requested compat version under compat.vN.
            add_imports_for_symbol(module_code_builder, attr, module.__name__,
                                   module_contents_name, api_name, api_version)
            for compat_api_version in compat_api_versions:
                add_imports_for_symbol(
                    module_code_builder, attr, module.__name__,
                    module_contents_name, api_name, compat_api_version,
                    _COMPAT_MODULE_TEMPLATE % compat_api_version)

    # Include compat.vN-1 under compat.vN.
    # For e.g. import compat.v1 under compat.v2.compat
    for version in compat_api_versions:
        if version - 1 in compat_api_versions:
            prev_version = 'v%d' % (version - 1)
            module_code_builder.add_import(
                symbol=None,
                source_module_name='%s.compat' % output_package,
                source_name=prev_version,
                dest_module_name='compat.v%d.compat' % version,
                dest_name=prev_version)

    return module_code_builder.build()
Beispiel #48
0
def signature(obj, *, follow_wrapped=True):
  """TFDecorator-aware replacement for inspect.signature."""
  _, target = tf_decorator.unwrap(obj)
  return _inspect.signature(target, follow_wrapped=follow_wrapped)
Beispiel #49
0
def func_graph_from_py_func(name,
                            python_func,
                            args,
                            kwargs,
                            signature=None,
                            func_graph=None,
                            experimental_autograph=False,
                            add_control_dependencies=True,
                            arg_names=None,
                            op_return_value=None):
    """Returns a `FuncGraph` generated from `python_func`.

  Args:
    name: an identifier for the function.
    python_func: the Python function to trace.
    args: the positional args with which the Python function should be called;
      ignored if a signature is provided.
    kwargs: the keyword args with which the Python function should be called;
      ignored if a signature is provided.
    signature: a possibly nested sequence of `TensorSpecs` specifying the shapes
      and dtypes of the arguments. When a signature is provided, `args` and
      `kwargs` are ignored, and `python_func` is traced with Tensors conforming
      to `signature`. If `None`, the shapes and dtypes are inferred from the
      inputs.
    func_graph: Optional. An instance of FuncGraph. If provided, we will use
      this graph else a new one is built and returned.
    experimental_autograph: whether to use autograph to compile `python_func`.
      See https://www.tensorflow.org/guide/autograph for more information.
    add_control_dependencies: If True, automatically adds control dependencies
      to ensure program order matches execution order and stateful ops always
      execute.
    arg_names: Optional list of argument names, used to give input placeholders
      recognizable names.
    op_return_value: Optional. A Tensor. If set and `python_func` returns
      Operations, those return values will be replaced with this value. If not
      set, returning an Operation triggers an error.

  Returns:
    A FuncGraph.

  Raises:
    TypeError: If any of `python_func`'s return values is neither `None` nor a
      `Tensor`.
  """
    if op_return_value is not None:
        assert isinstance(op_return_value, ops.Tensor), op_return_value
    if func_graph is None:
        func_graph = FuncGraph(name)
    assert isinstance(func_graph, FuncGraph)
    # Control dependencies are only auto-inserted when requested; otherwise a
    # no-op context manager keeps the `with` statement below uniform.
    if add_control_dependencies:
        control_manager = AutomaticControlDependencies
    else:
        control_manager = ops.NullContextmanager
    with func_graph.as_default(), control_manager() as a:
        current_scope = variable_scope.get_variable_scope()
        # NOTE(review): 'recource' is a misspelling of 'resource'; the local is
        # kept as-is here since this is a documentation-only pass.
        default_use_recource = current_scope.use_resource
        # Force resource variables while tracing; restored in the finally below.
        current_scope.set_use_resource(True)

        # When a signature is given it fully replaces args/kwargs.
        if signature is not None:
            args = signature
            kwargs = {}

        # Creates and names placeholders for all arguments.
        func_args = _get_defun_inputs_from_args(args, arg_names)
        func_kwargs = _get_defun_inputs_from_kwargs(kwargs)

        # Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
        # Variables to help check whether mutation happens in calling the function
        # Copy the recursive list, tuple and map structure, but not base objects
        func_args_before = nest.pack_sequence_as(func_args,
                                                 nest.flatten(func_args))
        func_kwargs_before = nest.pack_sequence_as(func_kwargs,
                                                   nest.flatten(func_kwargs))

        def convert(x):
            """Converts a function output to a Tensor."""
            if x is None:
                return None
            if op_return_value is not None and isinstance(x, ops.Operation):
                # TODO(b/79881896): we currently can't capture external control deps, so
                # this won't work if x needs to be captured (i.e. if python_func returns
                # captured Operations).
                with ops.control_dependencies([x]):
                    x = array_ops.identity(op_return_value)
            else:
                try:
                    x = ops.convert_to_tensor_or_indexed_slices(x)
                except (ValueError, TypeError):
                    raise TypeError(
                        "To be compatible with tf.contrib.eager.defun, Python functions "
                        "must return zero or more Tensors; in compilation of %s, found "
                        "return value of type %s, which is not a Tensor." %
                        (str(python_func), type(x)))
            if add_control_dependencies:
                x = a.mark_as_return(x)
            return x

        # Record variable accesses during tracing on a fresh tape.
        this_tape = tape.push_new_tape()
        try:
            if experimental_autograph:
                from tensorflow.python import autograph  # pylint: disable=g-import-not-at-top
                _, original_func = tf_decorator.unwrap(python_func)

                # AutoGraph does not yet rebind the returned method, and must receive
                # `self` explicitly.
                # TODO(mdan): Have the result automatically bind it instead.
                if (tf_inspect.ismethod(original_func)
                        and hasattr(original_func, "__self__")):
                    effective_func_args = (
                        original_func.__self__, ) + func_args
                else:
                    effective_func_args = func_args

                func_outputs = autograph.converted_call(
                    original_func, None,
                    autograph.ConversionOptions(
                        verbose=True,
                        recursive=True,
                        strip_decorators=(function.defun,
                                          def_function.function),
                        optional_features=(),
                    ), *effective_func_args, **func_kwargs)
            else:
                func_outputs = python_func(*func_args, **func_kwargs)
            # invariant: `func_outputs` contains only Tensors and `None`s.
            func_outputs = nest.map_structure(convert, func_outputs)

            # Tracing must not have mutated the argument structures in place.
            check_mutation(func_args_before, func_args)
            check_mutation(func_kwargs_before, func_kwargs)
        finally:
            tape.pop_tape(this_tape)
            # Restore the use_resource setting saved above.
            current_scope.set_use_resource(default_use_recource)

        # Variables in `func_args`, `func_kwargs` should be explicit inputs
        # to the function, not captured inputs.
        tape_variables = this_tape.watched_variables()
        arg_variables = set()
        inputs = []
        for arg in nest.flatten(func_args) + nest.flatten(func_kwargs):
            if isinstance(arg, resource_variable_ops.ResourceVariable):
                # Even if an argument variable was not used in the function, we've
                # already manually captured the resource Tensor when creating argument
                # placeholders.
                resource_placeholder = func_graph.captures.pop(arg.handle)
                arg_variables.add(arg)
                inputs.append(resource_placeholder)
            elif isinstance(arg, ops.Tensor):
                inputs.append(arg)
        variables = [v for v in tape_variables if v not in arg_variables]
        func_graph.inputs = inputs + list(func_graph.captures.values())

        func_graph.structured_outputs = func_outputs
        # Returning a closed-over tensor does not trigger convert_to_tensor.
        func_graph.outputs.extend(
            func_graph.capture(x)
            for x in flatten(func_graph.structured_outputs) if x is not None)

        func_graph.variables = variables

    # Register any other functions defined in the graph.
    with ops.init_scope():
        if context.executing_eagerly():
            for f in func_graph._functions.values():  # pylint: disable=protected-access
                # TODO(ashankar): What about the gradient registry?
                context.add_function(f._c_func.func)  # pylint: disable=protected-access

    return func_graph
Beispiel #50
0
def get_api_init_text():
    """Get a map from destination module to __init__.py code for that module.

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: (string) text that should be in __init__.py files for
        corresponding modules.
  """
    module_code_builder = _ModuleInitCodeBuilder()

    # Traverse over everything imported above. Specifically,
    # we want to traverse over TensorFlow Python modules.
    for module in sys.modules.values():
        # Only look at tensorflow modules.
        if (not module or not hasattr(module, '__name__')
                or 'tensorflow.' not in module.__name__):
            continue
        # Do not generate __init__.py files for contrib modules for now.
        if '.contrib.' in module.__name__ or module.__name__.endswith(
                '.contrib'):
            continue

        for module_contents_name in dir(module):
            attr = getattr(module, module_contents_name)

            # If attr is _tf_api_constants attribute, then add the constants.
            if module_contents_name == _API_CONSTANTS_ATTR:
                # `attr` holds (export names, value) pairs recorded by the
                # export decorator machinery.
                for exports, value in attr:
                    for export in exports:
                        names = export.split('.')
                        dest_module = '.'.join(names[:-1])
                        module_code_builder.add_import(-1, dest_module,
                                                       module.__name__, value,
                                                       names[-1])
                continue

            # Unwrap TFDecorator layers so export metadata on the underlying
            # target is visible.
            _, attr = tf_decorator.unwrap(attr)
            # If attr is a symbol with _tf_api_names attribute, then
            # add import for it.
            if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
                for export in attr._tf_api_names:  # pylint: disable=protected-access
                    names = export.split('.')
                    dest_module = '.'.join(names[:-1])
                    module_code_builder.add_import(id(attr), dest_module,
                                                   module.__name__,
                                                   module_contents_name,
                                                   names[-1])

    # Import all required modules in their parent modules.
    # For e.g. if we import 'foo.bar.Value'. Then, we also
    # import 'bar' in 'foo'.
    imported_modules = set(module_code_builder.module_imports.keys())
    for module in imported_modules:
        if not module:
            continue
        module_split = module.split('.')
        parent_module = ''  # we import submodules in their parent_module

        for submodule_index in range(len(module_split)):
            import_from = _OUTPUT_MODULE
            if submodule_index > 0:
                # Extend the dotted parent path one segment at a time.
                parent_module += ('.' + module_split[submodule_index - 1]
                                  if parent_module else
                                  module_split[submodule_index - 1])
                import_from += '.' + parent_module
            module_code_builder.add_import(-1, parent_module, import_from,
                                           module_split[submodule_index],
                                           module_split[submodule_index])

    return module_code_builder.build()
Beispiel #51
0
    def __call__(self, path, parent, children):
        """Build an API proto for the object at `path` and store it in self._protos.

        Args:
          path: dotted path of the object relative to the tensorflow root
            ('' for the root itself).
          parent: the object at `path` (module, class, or proto class).
          children: iterable of (name, child_object) pairs for the members.

        NOTE(review): presumably invoked as a visitor callback during API
        traversal — confirm against the caller.
        """
        # The path to the object.
        lib_path = 'tensorflow.%s' % path if path else 'tensorflow'
        _, parent = tf_decorator.unwrap(parent)

        # A small helper method to construct members(children) protos.
        def _AddMember(member_name, member_obj, proto):
            """Add the child object to the object being constructed."""
            _, member_obj = tf_decorator.unwrap(member_obj)
            if (_SkipMember(parent, member_name) or isinstance(
                    member_obj, deprecation.HiddenTfApiAttribute)):
                return
            # Only record __init__ and public (non-underscore) members.
            if member_name == '__init__' or not member_name.startswith('_'):
                if tf_inspect.isroutine(member_obj):
                    new_method = proto.member_method.add()
                    new_method.name = member_name
                    # If member_obj is a python builtin, there is no way to get its
                    # argspec, because it is implemented on the C side. It also has no
                    # func_code.
                    if hasattr(member_obj, '__code__'):
                        new_method.argspec = _SanitizedArgSpec(member_obj)
                else:
                    new_member = proto.member.add()
                    new_member.name = member_name
                    if tf_inspect.ismodule(member_obj):
                        new_member.mtype = "<type \'module\'>"
                    else:
                        new_member.mtype = _NormalizeType(str(
                            type(member_obj)))

        # Corner cases can override or suppress individual members at this path.
        parent_corner_cases = _CORNER_CASES.get(path, {})

        if path not in _CORNER_CASES or parent_corner_cases:
            # Decide if we have a module or a class.
            if tf_inspect.ismodule(parent):
                # Create a module object.
                module_obj = api_objects_pb2.TFAPIModule()
                for name, child in children:
                    if name in parent_corner_cases:
                        # If we have an empty entry, skip this object.
                        if parent_corner_cases[name]:
                            module_obj.member.add(
                                **(parent_corner_cases[name]))
                    else:
                        _AddMember(name, child, module_obj)

                # Store the constructed module object.
                self._protos[lib_path] = api_objects_pb2.TFAPIObject(
                    path=lib_path, tf_module=module_obj)
            elif _IsProtoClass(parent):
                proto_obj = api_objects_pb2.TFAPIProto()
                parent.DESCRIPTOR.CopyToProto(proto_obj.descriptor)

                # Store the constructed proto object.
                self._protos[lib_path] = api_objects_pb2.TFAPIObject(
                    path=lib_path, tf_proto=proto_obj)
            elif tf_inspect.isclass(parent):
                # Construct a class.
                class_obj = api_objects_pb2.TFAPIClass()
                class_obj.is_instance.extend(
                    _NormalizeIsInstance(i) for i in _SanitizedMRO(parent))
                for name, child in children:
                    if name in parent_corner_cases:
                        # If we have an empty entry, skip this object.
                        if parent_corner_cases[name]:
                            class_obj.member.add(**(parent_corner_cases[name]))
                    else:
                        _AddMember(name, child, class_obj)

                # Store the constructed class object.
                self._protos[lib_path] = api_objects_pb2.TFAPIObject(
                    path=lib_path, tf_class=class_obj)
            else:
                logging.error(
                    'Illegal call to ApiProtoDump::_py_obj_to_proto.'
                    'Object is neither a module nor a class: %s', path)
Beispiel #52
0
def get_api_imports():
    """Get a map from destination module to formatted imports.

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: List of strings representing module imports
          (for e.g. 'from foo import bar') and constant
          assignments (for e.g. 'FOO = 123').
  """
    module_imports_builder = _ModuleImportsBuilder()
    # Tracks ids of symbols already processed so a symbol reachable from
    # several modules is only imported once.
    visited_symbols = set()

    # Traverse over everything imported above. Specifically,
    # we want to traverse over TensorFlow Python modules.
    for module in sys.modules.values():
        # Only look at tensorflow modules.
        if not module or 'tensorflow.' not in module.__name__:
            continue
        # Do not generate __init__.py files for contrib modules for now.
        if '.contrib.' in module.__name__ or module.__name__.endswith(
                '.contrib'):
            continue

        for module_contents_name in dir(module):
            attr = getattr(module, module_contents_name)
            if id(attr) in visited_symbols:
                continue

            # If attr is _tf_api_constants attribute, then add the constants.
            if module_contents_name == _API_CONSTANTS_ATTR:
                for exports, value in attr:
                    for export in exports:
                        names = export.split('.')
                        dest_module = '.'.join(names[:-1])
                        module_imports_builder.add_import(
                            dest_module, module.__name__, value, names[-1])
                continue

            _, attr = tf_decorator.unwrap(attr)
            # If attr is a symbol with _tf_api_names attribute, then
            # add import for it.
            if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
                # If the same symbol is available using multiple names, only create
                # imports for it once.
                # (Re-checked here because unwrap may return a different object
                # than the one tested before the constants branch.)
                if id(attr) in visited_symbols:
                    continue
                visited_symbols.add(id(attr))

                for export in attr._tf_api_names:  # pylint: disable=protected-access
                    names = export.split('.')
                    dest_module = '.'.join(names[:-1])
                    module_imports_builder.add_import(dest_module,
                                                      module.__name__,
                                                      module_contents_name,
                                                      names[-1])

    # Import all required modules in their parent modules.
    # For e.g. if we import 'foo.bar.Value'. Then, we also
    # import 'bar' in 'foo'.
    imported_modules = set(module_imports_builder.module_imports.keys())
    for module in imported_modules:
        if not module:
            continue
        module_split = module.split('.')
        parent_module = ''  # we import submodules in their parent_module

        for submodule_index in range(len(module_split)):
            import_from = _OUTPUT_MODULE
            if submodule_index > 0:
                # Extend the dotted parent path one segment at a time.
                parent_module += ('.' + module_split[submodule_index - 1]
                                  if parent_module else
                                  module_split[submodule_index - 1])
                import_from += '.' + parent_module
            module_imports_builder.add_import(parent_module, import_from,
                                              module_split[submodule_index],
                                              module_split[submodule_index])

    return module_imports_builder.module_imports
 def get_argspec_with_decorator(obj):
   """TFDecorator-aware getargspec: prefer an argspec recorded by a decorator."""
   decorators, target = tf_decorator.unwrap(obj)
   # The fallback is computed eagerly, exactly as the original `next(...)`
   # call evaluated its default argument before scanning the decorators.
   fallback = _inspect.getargspec(target)
   for dec in decorators:
     if dec.decorator_argspec is not None:
       return dec.decorator_argspec
   return fallback
Beispiel #54
0
  def from_function_and_signature(cls, python_function,
                                  input_signature,
                                  is_pure=False,
                                  experimental_follow_type_hints=False,
                                  jit_compile=None):
    """Creates a FunctionSpec instance given a python function and signature.

    Args:
      python_function: a function to inspect
      input_signature: a signature of the function (None, if variable)
      is_pure: if True all input arguments (including variables and constants)
      will be converted to tensors and no variable changes allowed.
      experimental_follow_type_hints: see `tf.function`
      jit_compile: see `tf.function`

    Returns:
      instance of FunctionSpec

    Raises:
      ValueError: if `input_signature` is provided but the function has
        keyword-only arguments without default values.
    """
    fullargspec = tf_inspect.getfullargspec(python_function)
    # With an input_signature, every keyword-only argument must have a default
    # since the signature cannot describe values for them.
    if (input_signature is not None and
        set(fullargspec.kwonlyargs) - set(fullargspec.kwonlydefaults or ())):
      nodefault_kwonlyargs = set(fullargspec.kwonlyargs)
      if fullargspec.kwonlydefaults is not None:
        nodefault_kwonlyargs -= set(fullargspec.kwonlydefaults)
      raise ValueError("Cannot build TF function from "
                       f"{python_function.__name__}: keyword-only arguments "
                       "must have default values when input_signature is "
                       "provided. Got keyword-only arguments without default "
                       f"values: {sorted(nodefault_kwonlyargs)}.")

    # Checks if the `fullargspec` contains self or cls as its first argument.
    is_method = tf_inspect.isanytargetmethod(python_function)

    # Treat a wrapped partial function as a special case. For all arguments that
    # were overridden with keywords in the partial:
    #   - remove the corresponding arguments,
    #   - remove the corresponding keywords.
    _, unwrapped = tf_decorator.unwrap(python_function)
    if isinstance(unwrapped, functools.partial):
      # Also consider the Python3 case with kwonlydefaults.
      if fullargspec.defaults or fullargspec.kwonlydefaults:
        new_defaults = fullargspec.defaults
        new_args = fullargspec.args
        if fullargspec.defaults:
          # To be able to canonicalize the function properly, we want to ignore
          # default values that are overridden via a partial kwarg. For example:
          #
          #   def func(a, b, c, d=5, e=7):
          #     return a, b, c, d, e
          #   p_func = tf.function(functools.partial(func, 10, e=9))
          #
          # Here we want to drop from the defaults the parameter `e`. If we
          # forwarded the call to the partial function with a default for `e`
          # we would get an error for passing two values for one parameter.
          #
          # Note that this has a limitation: we can only override parameters at
          # the end of the parameter list.
          #
          # In this case we want to end up with 3 arguments (b, c, d) and 1
          # default value (5). We do this by constructing a mask where 0 stands
          # for a value that was overridden by a partial kwarg. The seemingly
          # complicated logic below does just that - for arguments (b, c, d, e)
          # we would get a mask (1, 1, 1, 0).
          old_args = fullargspec.args
          old_defaults = fullargspec.defaults

          # Sentinel marking positions that have no default value.
          no_default = object()
          num_args_without_defaults = len(old_args) - len(old_defaults)
          left_padding = tuple([no_default] * num_args_without_defaults)

          args_with_defaults = zip(old_args, left_padding + old_defaults)

          # Create a mask where 0 stands for args that had a partial kwarg
          # defined.
          non_keyword_defaults_mask = [
              0 if key in unwrapped.keywords else 1 for key in old_args
          ]
          # Keep only arguments and defaults that were not kwargs of partial.
          new_args_with_defaults = list(
              itertools.compress(args_with_defaults, non_keyword_defaults_mask))
          # Keep all args.
          new_args = [arg for arg, _ in new_args_with_defaults]
          # Keep only real default values.
          new_defaults = [
              default for _, default in new_args_with_defaults
              if default is not no_default
          ]
        # Rebuild the argspec without the partial-overridden entries; the
        # keyword-only fields are cleared deliberately.
        fullargspec = tf_inspect.FullArgSpec(
            args=new_args,
            varargs=fullargspec.varargs,
            varkw=fullargspec.varkw,
            defaults=new_defaults,
            kwonlyargs=[],
            kwonlydefaults={},
            annotations=fullargspec.annotations)

    # Get the function's name.  Remove functools.partial wrappers if necessary.
    while isinstance(python_function, functools.partial):
      python_function = python_function.func
    name = getattr(python_function, "__name__", "f")

    return FunctionSpec(
        fullargspec,
        is_method,
        input_signature,
        is_pure=is_pure,
        jit_compile=jit_compile,
        experimental_follow_type_hints=experimental_follow_type_hints,
        name=name)
Beispiel #55
0
def getsourcelines(object):  # pylint: disable=redefined-builtin
    """TFDecorator-aware replacement for inspect.getsourcelines."""
    _, unwrapped = tf_decorator.unwrap(object)
    return _inspect.getsourcelines(unwrapped)
Beispiel #56
0
def isgeneratorfunction(object):  # pylint: disable=redefined-builtin
    """TFDecorator-aware replacement for inspect.isgeneratorfunction."""
    _, unwrapped = tf_decorator.unwrap(object)
    return _inspect.isgeneratorfunction(unwrapped)
Beispiel #57
0
def get_api_imports():
    """Get a map from destination module to formatted imports.

  Traverses every module currently loaded in sys.modules, picks out
  TensorFlow modules, and collects two kinds of exported symbols:
    * constants recorded on a module's API-constants attribute, and
    * functions/classes tagged with the API-names attribute (after
      unwrapping TFDecorators).
  It then also ensures each destination submodule is imported in its
  parent module so the generated package hierarchy is navigable.

  Returns:
    A dictionary where
      key: (string) destination module (for e.g. tf or tf.consts).
      value: List of strings representing module imports
          (for e.g. 'from foo import bar') and constant
          assignments (for e.g. 'FOO = 123').
  """
    module_imports = collections.defaultdict(list)
    # Traverse over everything imported above. Specifically,
    # we want to traverse over TensorFlow Python modules.
    for module in sys.modules.values():
        # Only look at tensorflow modules.
        if not module or 'tensorflow.' not in module.__name__:
            continue
        # Do not generate __init__.py files for contrib modules for now.
        if '.contrib.' in module.__name__ or module.__name__.endswith(
                '.contrib'):
            continue

        for module_contents_name in dir(module):
            attr = getattr(module, module_contents_name)

            # If attr is _tf_api_constants attribute, then add the constants.
            # Each entry is a (list-of-export-names, value) pair; the constant
            # is re-exported under every requested destination name.
            if module_contents_name == _API_CONSTANTS_ATTR:
                for exports, value in attr:
                    for export in exports:
                        names = export.split('.')
                        # Everything before the final dot is the destination
                        # module; the last component is the exported name.
                        dest_module = '.'.join(names[:-1])
                        import_str = format_import(module.__name__, value,
                                                   names[-1])
                        module_imports[dest_module].append(import_str)
                continue

            # Strip TFDecorator wrappers so we inspect the underlying symbol.
            _, attr = tf_decorator.unwrap(attr)
            # If attr is a symbol with _tf_api_names attribute, then
            # add import for it. __dict__ is checked directly (not hasattr)
            # so inherited attributes on subclasses don't count as exports.
            if hasattr(attr, '__dict__') and _API_NAMES_ATTR in attr.__dict__:
                # The same op might be accessible from multiple modules.
                # We only want to consider location where function was defined.
                # Here we check if the op is defined in another TensorFlow module in
                # sys.modules.
                # NOTE(review): assumes attr.__module__ is a string when the
                # attribute exists — confirm for C-extension types.
                if (hasattr(attr, '__module__')
                        and attr.__module__.startswith(tf.__name__)
                        and attr.__module__ != module.__name__
                        and attr.__module__ in sys.modules
                        and module_contents_name in dir(
                            sys.modules[attr.__module__])):
                    continue

                for export in attr._tf_api_names:  # pylint: disable=protected-access
                    names = export.split('.')
                    dest_module = '.'.join(names[:-1])
                    import_str = format_import(module.__name__,
                                               module_contents_name, names[-1])
                    module_imports[dest_module].append(import_str)

    # Import all required modules in their parent modules.
    # For e.g. if we import 'foo.bar.Value'. Then, we also
    # import 'bar' in 'foo'.
    # Snapshot the keys first: the loop below may add new entries to
    # module_imports (defaultdict access creates missing keys).
    imported_modules = set(module_imports.keys())
    for module in imported_modules:
        if not module:
            continue
        module_split = module.split('.')
        parent_module = ''  # we import submodules in their parent_module

        # Walk the dotted path left-to-right, emitting one
        # "from <parent> import <child> as <child>" per level.
        for submodule_index in range(len(module_split)):
            import_from = _OUTPUT_MODULE
            if submodule_index > 0:
                parent_module += ('.' + module_split[submodule_index - 1]
                                  if parent_module else
                                  module_split[submodule_index - 1])
                import_from += '.' + parent_module
            submodule_import = format_import(import_from,
                                             module_split[submodule_index],
                                             module_split[submodule_index])
            # Avoid duplicate import lines for the same parent module.
            if submodule_import not in module_imports[parent_module]:
                module_imports[parent_module].append(submodule_import)

    return module_imports
Beispiel #58
0
def _is_bounded_method(fn):
    """Returns True if `fn` (after TFDecorator unwrapping) is a bound method."""
    unwrapped = tf_decorator.unwrap(fn)[1]
    if not tf_inspect.ismethod(unwrapped):
        return False
    return unwrapped.__self__ is not None
Beispiel #59
0
 def testUnwrapReturnsUndecoratedFunctionAsTarget(self):
     """Unwrapping an undecorated function yields the function itself."""
     unwrapped_target = tf_decorator.unwrap(test_function)[1]
     self.assertIs(test_function, unwrapped_target)
Beispiel #60
0
 def testUnwrapReturnsEmptyArrayForUndecoratedFunction(self):
     """An undecorated function carries no decorator chain entries."""
     decorator_chain = tf_decorator.unwrap(test_function)[0]
     self.assertEqual(0, len(decorator_chain))