示例#1
0
 def testWrapperHasAllPublicMethodsOfSession(self):
   """Verifies the debug wrapper exposes every public method of `Session`."""

   def public_method_names(cls):
     # `getmembers` yields (name, member) pairs; only the names matter here.
     return [
         name for name, _ in tf_inspect.getmembers(
             cls, predicate=tf_inspect.ismethod)
         if _is_public_method_name(name)]

   expected = public_method_names(session.Session)
   provided = public_method_names(framework.BaseDebugWrapperSession)
   absent = [name for name in expected if name not in provided]
   # An empty list is falsy, so this passes only when nothing is missing.
   self.assertFalse(absent)
示例#2
0
def _traverse_internal(root, visit, stack, path):
  """Internal helper for traverse."""

  # Only traverse modules and classes
  if not tf_inspect.isclass(root) and not tf_inspect.ismodule(root):
    return

  try:
    children = tf_inspect.getmembers(root)

    # Add labels for duplicate values in Enum.
    if tf_inspect.isclass(root) and issubclass(root, enum.Enum):
      for enum_member in root.__members__.items():
        if enum_member not in children:
          children.append(enum_member)
      children = sorted(children)
  except ImportError:
    # On some Python installations, some modules do not support enumerating
    # members (six in particular), leading to import errors.
    children = []

  new_stack = stack + [root]
  visit(path, root, children)
  for name, child in children:
    # Do not descend into built-in modules
    if tf_inspect.ismodule(
        child) and child.__name__ in sys.builtin_module_names:
      continue

    # Break cycles
    if any(child is item for item in new_stack):  # `in`, but using `is`
      continue

    child_path = path + '.' + name if path else name
    _traverse_internal(child, visit, new_stack, child_path)
示例#3
0
 def parameters(self):
   """Maps each `@parameter` property name to its value on this instance.

   A member of the instance's type counts when it has an `fget` attribute and
   that `fget` carries an `is_parameter` attribute.
   """

   def is_parameter_property(member):
     if not hasattr(member, "fget"):
       return False
     return hasattr(member.fget, "is_parameter")

   # Collect all qualifying names first; the property getters only run below.
   marked_names = [
       attr_name
       for attr_name, member in tf_inspect.getmembers(type(self))
       if is_parameter_property(member)]
   return {attr_name: getattr(self, attr_name) for attr_name in marked_names}
 def parameters(self):
   """Returns a dict of `@parameter` property names to their current values.

   Only members of `type(self)` that expose an `fget` which itself has an
   `is_parameter` attribute are included.
   """
   selected = []
   for member_name, member in tf_inspect.getmembers(type(self)):
     if not hasattr(member, "fget"):
       continue
     if hasattr(getattr(member, "fget"), "is_parameter"):
       selected.append(member_name)
   # Values are read from the instance only after the name scan is complete.
   return {member_name: getattr(self, member_name) for member_name in selected}
示例#5
0
def class_to_graph(c, conversion_map):
    """Specialization of `entity_to_graph` for classes.

    Converts every function/method member of `c` with `function_to_graph` and
    wraps the converted nodes in a `gast.ClassDef`.

    Args:
        c: The class to convert.
        conversion_map: Forwarded to `function_to_graph`; also supplies the
            namer that picks the compiled class name.

    Returns:
        A tuple `(node, class_name)`: the class definition node and the name
        chosen for the compiled class.

    Raises:
        ValueError: If `c` has no function or method members.
    """

    def is_callable_member(member):
        return tf_inspect.isfunction(member) or tf_inspect.ismethod(member)

    methods = tf_inspect.getmembers(c, predicate=is_callable_member)
    if not methods:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    namespace = None
    converted = {}
    for _, method in methods:
        method_node, _ = function_to_graph(
            method,
            conversion_map=conversion_map,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        # TODO(mdan): Do not assume all members have the same view of globals.
        if namespace is None:
            namespace = inspect_utils.getnamespace(method)
        converted[method] = method_node

    namer = conversion_map.new_namer(namespace)
    class_name = namer.compiled_class_name(c.__name__, c)
    class_def = gast.ClassDef(
        class_name,
        bases=[],
        keywords=[],
        body=list(converted.values()),
        decorator_list=[])

    return class_def, class_name
示例#6
0
def class_to_graph(c, conversion_map):
  """Specialization of `entity_to_graph` for classes.

  Converts every function/method member of `c` with `function_to_graph` and
  assembles the results into a single `gast.ClassDef`.

  Args:
    c: The class to convert.
    conversion_map: Forwarded to `function_to_graph`; also supplies the namer
      that picks the compiled class name.

  Returns:
    A tuple `(node, class_name, class_namespace)`: the `gast.ClassDef` node,
    the name chosen for the compiled class, and the merged namespace dict of
    all converted members.

  Raises:
    ValueError: If `c` has no function or method members.
  """
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  # Starts as an empty dict (never None), so each member's namespace is simply
  # merged in; on key clashes the later member wins.
  class_namespace = {}
  for _, m in members:
    node, _, namespace = function_to_graph(
        m,
        conversion_map=conversion_map,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    class_namespace.update(namespace)
    converted_members[m] = node
  namer = conversion_map.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)
  node = gast.ClassDef(
      class_name,
      bases=[],
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])

  return node, class_name, class_namespace
示例#7
0
def class_to_graph(c, conversion_map, param_value_hints):
  """Specialization of `object_to_graph` for classes.

  Converts every method member of `c` with `function_to_graph` and wraps the
  converted nodes in a `gast.ClassDef`.

  Args:
    c: The class to convert.
    conversion_map: Forwarded to `function_to_graph`; also supplies the namer
      that picks the compiled class name.
    param_value_hints: Dict of hints forwarded to `function_to_graph`. It must
      not contain the key 'self'; that entry is added here.

  Returns:
    A tuple `(node, class_name)`: the class definition node and the name
    chosen for the compiled class.

  Raises:
    ValueError: If `c` has no member methods, or if `param_value_hints`
      already contains an entry for 'self'.
  """
  converted_members = {}
  members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod)
  if not members:
    # The class must be interpolated, otherwise the message shows a raw '%s'.
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  if 'self' in param_value_hints:
    raise ValueError('Hints may not be provided for reserved name "self".')
  # NOTE(review): this writes into the caller's dict, so passing the same dict
  # to a second call would trip the reserved-name check above.
  param_value_hints['self'] = (c.__name__, c)

  class_globals = None
  for _, m in members:
    node, _ = function_to_graph(m, conversion_map, param_value_hints, c)
    # TODO(mdan): Do not assume all members have the same view of globals.
    if class_globals is None:
      class_globals = six.get_function_globals(m)
    converted_members[m] = node
  namer = conversion_map.new_namer(class_globals)
  class_name = namer.compiled_class_name(c.__name__, c)
  node = gast.ClassDef(
      class_name,
      bases=[],
      keywords=[],
      # Materialize the values: on Python 3 `dict.values()` is a view, not the
      # list an AST node body is expected to be.
      body=list(converted_members.values()),
      decorator_list=[])

  return node, class_name
示例#8
0
def class_to_graph(c, conversion_map):
  """Specialization of `entity_to_graph` for classes.

  Converts every method member of `c` with `function_to_graph` and wraps the
  converted nodes in a `gast.ClassDef`.

  Args:
    c: The class to convert.
    conversion_map: Forwarded to `function_to_graph`; also supplies the namer
      that picks the compiled class name.

  Returns:
    A tuple `(node, class_name)`: the class definition node and the name
    chosen for the compiled class.

  Raises:
    ValueError: If `c` has no member methods.
  """
  converted_members = {}
  members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod)
  if not members:
    # The class must be interpolated, otherwise the message shows a raw '%s'.
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  class_globals = None
  for _, m in members:
    node, _ = function_to_graph(
        m,
        conversion_map=conversion_map,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    # TODO(mdan): Do not assume all members have the same view of globals.
    if class_globals is None:
      class_globals = six.get_function_globals(m)
    converted_members[m] = node
  namer = conversion_map.new_namer(class_globals)
  class_name = namer.compiled_class_name(c.__name__, c)
  node = gast.ClassDef(
      class_name,
      bases=[],
      keywords=[],
      # Materialize the values: on Python 3 `dict.values()` is a view, not the
      # list an AST node body is expected to be.
      body=list(converted_members.values()),
      decorator_list=[])

  return node, class_name
示例#9
0
def class_to_graph(c, conversion_map, param_value_hints):
  """Specialization of `object_to_graph` for classes.

  Converts every method member of `c` with `function_to_graph` and wraps the
  converted nodes in a `gast.ClassDef`.

  Args:
    c: The class to convert.
    conversion_map: Forwarded to `function_to_graph`; also supplies the namer
      that picks the compiled class name.
    param_value_hints: Dict of hints forwarded to `function_to_graph`. It must
      not contain the key 'self'; that entry is added here.

  Returns:
    A tuple `(node, class_name)`: the class definition node and the name
    chosen for the compiled class.

  Raises:
    ValueError: If `c` has no member methods, or if `param_value_hints`
      already contains an entry for 'self'.
  """
  converted_members = {}
  members = tf_inspect.getmembers(c, predicate=tf_inspect.ismethod)
  if not members:
    # The class must be interpolated, otherwise the message shows a raw '%s'.
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  if 'self' in param_value_hints:
    raise ValueError('Hints may not be provided for reserved name "self".')
  # NOTE(review): this writes into the caller's dict, so passing the same dict
  # to a second call would trip the reserved-name check above.
  param_value_hints['self'] = (c.__name__, c)

  class_globals = None
  for _, m in members:
    node, _ = function_to_graph(m, conversion_map, param_value_hints, c)
    # TODO(mdan): Do not assume all members have the same view of globals.
    if class_globals is None:
      class_globals = six.get_function_globals(m)
    converted_members[m] = node
  namer = conversion_map.new_namer(class_globals)
  class_name = namer.compiled_class_name(c.__name__, c)
  node = gast.ClassDef(
      class_name,
      bases=[],
      keywords=[],
      # Materialize the values: on Python 3 `dict.values()` is a view, not the
      # list an AST node body is expected to be.
      body=list(converted_members.values()),
      decorator_list=[])

  return node, class_name
示例#10
0
def make_all(module_name, doc_string_modules=None):
  """Builds an `__all__` list from the references in module docstrings.

  Usage: `make_all(__name__)` or
  `make_all(__name__, [sys.modules[__name__], other_module])`. Each doc string
  module must have a docstring. The result holds every symbol that is
  referenced in one of those docstrings (as matched by `_reference_pattern`)
  and that currently exists in the module named `module_name`.

  Args:
    module_name: The name of the module (usually `__name__`).
    doc_string_modules: A list of modules whose docstrings are scanned. If
      None, only the module named `module_name` is scanned.

  Returns:
    A list suitable for use as `__all__`.
  """
  target_module = _sys.modules[module_name]
  if doc_string_modules is None:
    doc_string_modules = [target_module]
  available = {name for name, _ in _tf_inspect.getmembers(target_module)}

  exported = set()
  for module in doc_string_modules:
    for match in _reference_pattern.finditer(module.__doc__):
      symbol = match.group(1)
      # Referenced names that the target module does not define are dropped.
      if symbol in available:
        exported.add(symbol)
  return list(exported)
示例#11
0
def make_all(module_name, doc_string_modules=None):
    """Generates `__all__` from the docstring of one or more modules.

    Usage: `make_all(__name__)` or
    `make_all(__name__, [sys.modules[__name__], other_module])`. Each doc
    string module must have a docstring. `__all__` will contain every symbol
    referenced in those docstrings (as matched by `_reference_pattern`) that
    currently exists in the module named `module_name`.

    Args:
        module_name: The name of the module (usually `__name__`).
        doc_string_modules: A list of modules from which to take docstrings.
            If None, a list containing only the module named `module_name`
            is used.

    Returns:
        A list suitable for use as `__all__`.
    """
    owner = _sys.modules[module_name]
    if doc_string_modules is None:
        doc_string_modules = [owner]

    defined_names = set()
    for member_name, _ in _tf_inspect.getmembers(owner):
        defined_names.add(member_name)

    referenced = set()
    for doc_module in doc_string_modules:
        matches = _reference_pattern.finditer(doc_module.__doc__)
        candidates = [match.group(1) for match in matches]
        # Keep only references to names the owning module actually defines.
        referenced.update(
            name for name in candidates if name in defined_names)
    return list(referenced)
示例#12
0
 def testWrapperHasAllPublicMethodsOfSession(self):
     """Checks that no public `Session` method is missing from the wrapper."""

     def collect_public(cls):
         # Names of the method members of `cls` that pass the public filter.
         names = []
         for name, _ in tf_inspect.getmembers(
                 cls, predicate=tf_inspect.ismethod):
             if _is_public_method_name(name):
                 names.append(name)
         return names

     session_names = collect_public(session.Session)
     wrapper_names = collect_public(framework.BaseDebugWrapperSession)
     uncovered = [
         name for name in session_names if name not in wrapper_names
     ]
     # Passes only when the list of uncovered names is empty.
     self.assertFalse(uncovered)
示例#13
0
def gen_register_op(source, method_prefix=None):
    """Emits op-registration text for the functions found in `source`.

    Args:
        source: Object whose function members are passed to `op_reg_gen`.
        method_prefix: If truthy, only functions whose name starts with this
            prefix are processed; otherwise every function is.

    Returns:
        A string: an `#include` line and the opening of a `tensorflow`
        namespace, the generated pieces joined by newlines, then the closing
        of the namespace.
    """
    headers = r"""
#include "tensorflow/core/framework/op.h"

namespace tensorflow {
  """
    registrations = []
    for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction):
        if method_prefix and not name.startswith(method_prefix):
            continue
        registrations.append(op_reg_gen(func))
    return headers + '\n'.join(registrations) + '}  // namespace tensorflow\n'
def assert_estimator_contract(tester, estimator_class):
  """Asserts that an estimator class exposes the expected attributes.

  This is not an exhaustive contract check; it only guards against a
  pre-canned Estimator forgetting to implement one of the listed members.

  Args:
    tester: A tf.test.TestCase.
    estimator_class: 'type' object of pre-canned estimator.
  """
  member_names = [name for name, _ in tf_inspect.getmembers(estimator_class)]

  required_names = (
      'config',
      'evaluate',
      'export',
      'fit',
      'get_variable_names',
      'get_variable_value',
      'model_dir',
      'predict',
  )
  # One assertion per required name, in this fixed order.
  for required_name in required_names:
    tester.assertTrue(required_name in member_names)
示例#15
0
def assert_estimator_contract(tester, estimator_class):
    """Asserts whether the given estimator satisfies the expected contract.

    Not every detail of the contract is checked. The test exists so that a
    pre-canned Estimator does not forget to implement one of these members.

    Args:
        tester: A tf.test.TestCase.
        estimator_class: 'type' object of pre-canned estimator.
    """
    present = []
    for attribute_name, _ in tf_inspect.getmembers(estimator_class):
        present.append(attribute_name)

    # Checked in this order; each name gets its own assertion.
    expected = [
        'config', 'evaluate', 'export', 'fit', 'get_variable_names',
        'get_variable_value', 'model_dir', 'predict'
    ]
    for attribute_name in expected:
        tester.assertTrue(attribute_name in present)
  def test_nccl_ops(self):
    """Tests behavior of nccl ops when NCCL is not installed.

    Every public function of `nccl` is called on the CPU device and is
    expected to raise `NotFoundError` about a missing shared object file.
    """
    cpu_device = '/device:CPU:0'
    for op_name, _ in tf_inspect.getmembers(nccl, tf_inspect.isfunction):
      if op_name.startswith('_'):
        continue

      with ops.device(cpu_device):
        tensor = constant_op.constant(1)

      # 'broadcast' receives the bare tensor; every other op gets a list.
      arg = tensor if op_name == 'broadcast' else [tensor]

      nccl_op = getattr(nccl, op_name)
      with ops.device(cpu_device):
        with self.assertRaisesRegexp(errors_impl.NotFoundError,
                                     r'cannot open shared object file'):
          nccl_op(arg)
    def test_nccl_ops(self):
        """Tests behavior of nccl ops when NCCL is not installed.

        Each public function of `nccl` is invoked on the CPU device and must
        raise `NotFoundError` mentioning the missing shared object file.
        """
        function_members = tf_inspect.getmembers(nccl, tf_inspect.isfunction)
        public_names = []
        for member_name, _ in function_members:
            if not member_name.startswith('_'):
                public_names.append(member_name)

        for member_name in public_names:
            with ops.device('/device:CPU:0'):
                tensor = constant_op.constant(1)

            # Only 'broadcast' takes the tensor itself rather than a list.
            arg = [tensor]
            if member_name == 'broadcast':
                arg = tensor

            nccl_op = getattr(nccl, member_name)
            with ops.device('/device:CPU:0'):
                with self.assertRaisesRegexp(
                        errors_impl.NotFoundError,
                        r'cannot open shared object file'):
                    nccl_op(arg)
示例#18
0
def tfr_gen_from_module(source, method_prefix=None, op_libraries=None):
  """Emits TFR functions for the functions found in `source`.

  Args:
    source: Object whose function members are passed to `tfr_gen`.
    method_prefix: If truthy, only functions whose name starts with this
      prefix are processed; otherwise every function is.
    op_libraries: Optional iterable of modules. For each one, a shared library
      path is derived from its `__file__` and loaded.

  Returns:
    The generated functions followed by `op_defs.mlir_external_funcs()`,
    joined by newlines.
  """
  op_defs = OpDefCache()

  # Load the op library so the op is added to the op registry. This is
  # required when the op cc_library couldn't be statically linked in open
  # source.
  # NOTE(review): the original note says this is a no-op when the shared
  # library is missing; what `load_op_library` does for a missing file is not
  # visible here.
  for op_module in op_libraries or ():
    lib_dir, py_file_name = os.path.split(op_module.__file__)
    # Drop the leading 'gen_' and swap the extension to get the library name.
    prefix_len = len('gen_')
    lib_name = py_file_name[prefix_len:].replace('.py', '.so')
    load_library.load_op_library(os.path.join(lib_dir, lib_name))

  mlir_funcs = []
  for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction):
    if method_prefix and not name.startswith(method_prefix):
      continue
    mlir_funcs.append(tfr_gen(func, op_defs))

  return '\n'.join(mlir_funcs + op_defs.mlir_external_funcs())
示例#19
0
def tfr_gen_from_module(source, method_prefix=None, op_libraries=None):
  """Parses the input source module and emits the TFR functions.

  Args:
    source: Module whose function members are passed to `tfr_gen`. When
      `op_libraries` is empty or None, every `.so` file in this module's
      directory is loaded.
    method_prefix: If truthy, only functions whose name starts with this
      prefix are processed; otherwise every function is.
    op_libraries: Optional iterable of modules. For each one, a shared library
      path is derived from its `__file__` and loaded if that file exists.

  Returns:
    The generated functions followed by `op_defs.mlir_external_funcs()`,
    joined by newlines.
  """
  op_defs = OpDefCache()

  def load_shared_library(lib_path):
    logging.info('load file: ' + lib_path)
    load_library.load_op_library(lib_path)

  # Load the op library so the op is added to the op registry. This is
  # required when the op cc_library couldn't be statically linked in open
  # source.
  # TODO(fengliuai): make the .so file path configurable.
  if not op_libraries:
    # The op library is generated from the source module, so load every .so
    # file in its directory.
    source_dir = os.path.dirname(source.__file__)
    for entry in os.listdir(source_dir):
      if entry.endswith('.so'):
        load_shared_library(os.path.join(source_dir, entry))
  else:
    prefix_len = len('gen_')
    for op_module in op_libraries:
      lib_dir, py_file_name = os.path.split(op_module.__file__)
      # Drop the leading 'gen_' and swap the extension to get the library name.
      lib_name = py_file_name[prefix_len:].replace('.py', '.so')
      candidate = os.path.join(lib_dir, lib_name)
      # Libraries that are not next to the Python API file are skipped.
      if os.path.exists(candidate):
        load_shared_library(candidate)

  mlir_funcs = []
  for name, func in tf_inspect.getmembers(source, tf_inspect.isfunction):
    if method_prefix and not name.startswith(method_prefix):
      continue
    mlir_funcs.append(tfr_gen(func, op_defs))

  return '\n'.join(mlir_funcs + op_defs.mlir_external_funcs())
示例#20
0
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts the methods defined directly on `c`, then builds a new class
    definition whose bases are either imported (whitelisted bases) or named
    so that they get converted as well.

    Args:
        c: The class to convert.
        program_ctx: Conversion context. It is forwarded to
            `function_to_graph`, supplies the namer, and is updated with the
            names that namer produced.

    Returns:
        A tuple `(output_nodes, class_name, class_namespace)`: the list of AST
        nodes (import statements for whitelisted bases followed by the class
        definition), the name of the compiled class, and the merged namespace
        dict of the converted methods.

    Raises:
        ValueError: If `c` has no function or method members.
    """
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m
                                                                              )
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % c)

    # Starts as an empty dict (never None), so each method's namespace is
    # simply merged in; on key clashes the later method wins.
    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        node, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            owner_type=c)
        class_namespace.update(namespace)
        converted_members[m] = node[0]
    namer = program_ctx.new_namer(class_namespace)
    class_name = namer.compiled_class_name(c.__name__, c)

    # TODO(mdan): This needs to be explained more thoroughly.
    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated. Otherwise, it is marked for conversion
    # (as a side effect of the call to namer.compiled_class_name() followed by
    # program_ctx.update_name_map(namer)).
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # `object` needs neither an import nor a conversion. Identity is the
        # right test: `isinstance(object, base)` would also be true for
        # `type`, since the class `object` is itself an instance of `type`.
        if base is object:
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            # This will trigger a conversion into a class with this name.
            alias = namer.compiled_class_name(base.__name__, base)
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
    program_ctx.update_name_map(namer)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
示例#21
0
from tensorflow_docs.api_generator import doc_generator_visitor
from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import parser

import tensorboard
import tensorflow_estimator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect

# Point the doc parser at tensorflow's `tf_inspect`, which is aware of
# `tf_decorator`.
parser.tf_inspect = tf_inspect

# The doc generator recognizes `__all__` as the list of public symbols, but
# `tf.__all__` doesn't list important things like `keras`. Patch it so that
# it lists the name of every member of `tf`.
tf.__all__ = [member_name for member_name, _ in tf_inspect.getmembers(tf)]

FLAGS = flags.FLAGS

flags.DEFINE_string(
    "code_url_prefix",
    "/code/stable/tensorflow",
    "A url to prepend to code paths when creating links to defining code",
)

flags.DEFINE_string(
    "output_dir",
    "/tmp/out",
    "A directory, where the docs will be output to.",
)

flags.DEFINE_bool(
    "search_hints",
    True,
    "Include meta-data search hints at the top of each file.",
)

flags.DEFINE_string(
    "site_path", "",
示例#22
0
def class_to_graph(c, program_ctx):
  """Specialization of `entity_to_graph` for classes.

  Converts the methods defined directly on `c`, then builds a new class
  definition whose bases are either imported (whitelisted bases) or named so
  that they get converted as well.

  Args:
    c: The class to convert.
    program_ctx: Conversion context. It is forwarded to `function_to_graph`,
      supplies the namer, and is updated with the names that namer produced.

  Returns:
    A tuple `(output_nodes, class_name, class_namespace)`: the list of AST
    nodes (import statements for whitelisted bases followed by the class
    definition), the name of the compiled class, and the merged namespace dict
    of the converted methods.

  Raises:
    ValueError: If `c` has no function or method members.
  """
  converted_members = {}
  method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m)
  members = tf_inspect.getmembers(c, predicate=method_filter)
  if not members:
    raise ValueError('Cannot convert %s: it has no member methods.' % c)

  # Starts as an empty dict (never None), so each method's namespace is simply
  # merged in; on key clashes the later method wins.
  class_namespace = {}
  for _, m in members:
    # Only convert the members that are directly defined by the class.
    if inspect_utils.getdefiningclass(m, c) is not c:
      continue
    node, _, namespace = function_to_graph(
        m,
        program_ctx=program_ctx,
        arg_values={},
        arg_types={'self': (c.__name__, c)},
        owner_type=c)
    class_namespace.update(namespace)
    converted_members[m] = node[0]
  namer = program_ctx.new_namer(class_namespace)
  class_name = namer.compiled_class_name(c.__name__, c)

  # TODO(mdan): This needs to be explained more thoroughly.
  # Process any base classes: if the superclass is of a whitelisted type, an
  # absolute import line is generated. Otherwise, it is marked for conversion
  # (as a side effect of the call to namer.compiled_class_name() followed by
  # program_ctx.update_name_map(namer)).
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    # `object` needs neither an import nor a conversion. Identity is the right
    # test: `isinstance(object, base)` would also be true for `type`, since
    # the class `object` is itself an instance of `type`.
    if base is object:
      base_names.append('object')
      continue
    if is_whitelisted_for_graph(base):
      alias = namer.new_symbol(base.__name__, ())
      output_nodes.append(
          gast.ImportFrom(
              module=base.__module__,
              names=[gast.alias(name=base.__name__, asname=alias)],
              level=0))
    else:
      # This will trigger a conversion into a class with this name.
      alias = namer.compiled_class_name(base.__name__, base)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)
  program_ctx.update_name_map(namer)

  # Generate the definition of the converted class.
  bases = [gast.Name(n, gast.Load(), None) for n in base_names]
  class_def = gast.ClassDef(
      class_name,
      bases=bases,
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Make a final pass to replace references to the class or its base classes.
  # Most commonly, this occurs when making super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  return output_nodes, class_name, class_namespace
示例#23
0
def convert_class_to_ast(c, program_ctx):
    """Specialization of `convert_entity_to_ast` for classes.

    Converts the methods defined directly on `c` and builds a new class
    definition. Every base other than `object` must be whitelisted; an import
    statement is generated for each such base.

    Args:
        c: The class to convert.
        program_ctx: Conversion context, forwarded to `convert_func_to_ast`.

    Returns:
        A tuple `(output_nodes, class_name, entity_info)`: the list of AST
        nodes (imports for the bases followed by the class definition), the
        name of the converted class, and a `transformer.EntityInfo` carrying
        the merged namespace and the future features of the methods.

    Raises:
        ValueError: If `c` has no function or method members, or if its
            methods were built with different future features.
        NotImplementedError: If a base class other than `object` is not
            whitelisted.
    """
    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    converted_members = {}
    method_filter = lambda m: tf_inspect.isfunction(m) or tf_inspect.ismethod(m
                                                                              )
    members = tf_inspect.getmembers(c, predicate=method_filter)
    if not members:
        raise ValueError('cannot convert %s: no member methods' % c)

    # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
    # The assumption that one namespace suffices for all methods only holds if
    # all methods were defined in the same module.
    # If, instead, functions are imported from multiple modules and then spliced
    # into the class, then each function has its own globals and __future__
    # imports that need to stay separate.

    # For example, C's methods could both have `global x` statements referring to
    # mod1.x and mod2.x, but using one namespace for C would cause a conflict.
    # from mod1 import f1
    # from mod2 import f2
    # class C(object):
    #   method1 = f1
    #   method2 = f2

    class_namespace = {}
    future_features = None
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        (node, ), _, entity_info = convert_func_to_ast(m,
                                                       program_ctx=program_ctx,
                                                       do_rename=False)
        class_namespace.update(entity_info.namespace)
        converted_members[m] = node

        # TODO(mdan): Similarly check the globals.
        if future_features is None:
            future_features = entity_info.future_features
        elif frozenset(future_features) ^ frozenset(
                entity_info.future_features):
            # A non-empty symmetric difference means the two sets differ.
            # Note: we can support this case if ever needed.
            raise ValueError(
                'cannot convert {}: if has methods built with mismatched future'
                ' features: {} and {}'.format(c, future_features,
                                              entity_info.future_features))
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # `object` needs no import. Identity is the right test:
        # `isinstance(object, base)` would also be true for `type`, since the
        # class `object` is itself an instance of `type`.
        if base is object:
            base_names.append('object')
            continue
        if is_whitelisted_for_graph(base):
            alias = namer.new_symbol(base.__name__, ())
            output_nodes.append(
                gast.ImportFrom(
                    module=base.__module__,
                    names=[gast.alias(name=base.__name__, asname=alias)],
                    level=0))
        else:
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    # TODO(mdan): Find a way better than forging this object.
    entity_info = transformer.EntityInfo(source_code=None,
                                         source_file=None,
                                         future_features=future_features,
                                         namespace=class_namespace)

    return output_nodes, class_name, entity_info
示例#24
0
def build_docs(output_dir, code_url_prefix, search_hints):
    """Build api docs for tensorflow v2.

    Configures what is documented, runs the doc generator, then sanity-checks
    the generated files.

    Args:
        output_dir: A string path, where to put the files.
        code_url_prefix: prefix for "Defined in" links.
        search_hints: Bool. Include meta-data search hints at the top of each
            file.

    Raises:
        ValueError: If a generated page lacks its expected "view source"
            target, if a page contains a rejected target, or if fewer than
            `MIN_NUM_FILES_EXPECTED` paths were produced under `output_dir`.
    """
    if distutils.version.LooseVersion(tf.__version__) >= "2.9":
        doc_controls.set_deprecated(tf.keras.preprocessing)

    # Use the custom page builder for the tf.raw_ops page.
    doc_controls.set_custom_page_builder_cls(tf.raw_ops, RawOpsPageInfo)

    # Hide raw_ops from search.
    for name, obj in tf_inspect.getmembers(tf.raw_ops):
        if not name.startswith("_"):
            doc_controls.hide_from_search(obj)

    for cls in [
            tf.Module, tf.keras.layers.Layer, tf.keras.optimizers.Optimizer
    ]:
        doc_controls.decorate_all_class_attributes(
            decorator=doc_controls.do_not_doc_in_subclasses,
            cls=cls,
            skip=["__init__"])

    do_not_document = [
        "tf.__internal__", "tf.keras.__internal__", "tf.__operators__",
        "tf.tools", "tf.compat.v1.pywrap_tensorflow", "tf.pywrap_tensorflow",
        "tf.flags", "tf.batch_mat_mul_v3", "tf.sparse_segment_sum_grad"
    ]
    for path in do_not_document:
        # Walk the dotted path from `tf`; a missing attribute yields None,
        # which then propagates through the remaining parts.
        item = tf
        for part in path.split(".")[1:]:
            item = getattr(item, part, None)
        if item is None:
            continue
        doc_controls.do_not_generate_docs(item)

    base_dirs, code_url_prefixes = base_dir.get_base_dirs_and_prefixes(
        code_url_prefix)
    doc_generator = generate_lib.DocGenerator(
        root_title="TensorFlow 2",
        py_modules=[("tf", tf)],
        base_dir=base_dirs,
        search_hints=search_hints,
        code_url_prefix=code_url_prefixes,
        site_path=FLAGS.site_path,
        visitor_cls=TfExportAwareVisitor,
        private_map=_PRIVATE_MAP,
        extra_docs=_EXTRA_DOCS)

    doc_generator.build(output_dir)

    out_path = pathlib.Path(output_dir)

    # Each generated page (key) must mention the given source path (value).
    expected_path_contents = {
        "tf/summary/audio.md": "tensorboard/plugins/audio/summary_v2.py",
        "tf/estimator/DNNClassifier.md":
        "tensorflow_estimator/python/estimator/canned/dnn.py",
        "tf/nn/sigmoid_cross_entropy_with_logits.md": "python/ops/nn_impl.py",
        "tf/keras/Model.md": "keras/engine/training.py",
    }

    all_passed = True
    error_msg_parts = [
        'Some "view source" links seem to be broken, please check:'
    ]

    for (rel_path, contents) in expected_path_contents.items():
        path = out_path / rel_path
        if contents not in path.read_text():
            all_passed = False
            error_msg_parts.append("  " + str(path))

    if not all_passed:
        raise ValueError("\n".join(error_msg_parts))

    # Each generated page (key) must NOT mention the given source path (value).
    rejected_path_contents = {
        "tf/keras/optimizers.md": "keras/optimizers/__init__.py",
    }

    all_passed = True
    error_msg_parts = [
        'Bad "view source" links in generated files, please check:'
    ]
    for rel_path, content in rejected_path_contents.items():
        path = out_path / rel_path
        if content in path.read_text():
            all_passed = False
            error_msg_parts.append("  " + str(path))

    if not all_passed:
        raise ValueError("\n".join(error_msg_parts))

    # Counts every path under the output directory, recursively.
    num_files = len(list(out_path.rglob("*")))
    if num_files < MIN_NUM_FILES_EXPECTED:
        # The trailing space keeps the two adjacent f-strings from running
        # together as "files(found ...".
        raise ValueError(
            f"The TensorFlow api should be more than {MIN_NUM_FILES_EXPECTED} files "
            f"(found {num_files}).")
示例#25
0
def class_to_graph(c, program_ctx):
    """Specialization of `entity_to_graph` for classes.

    Converts every function/method defined directly on `c` with
    `function_to_graph` and assembles the converted members into a new
    `gast.ClassDef`.

    Args:
      c: The class to convert.
      program_ctx: Conversion context, forwarded unchanged to
        `function_to_graph`.

    Returns:
      A tuple `(output_nodes, class_name, class_namespace)`:
        output_nodes: list holding one `gast.ImportFrom` per whitelisted base
          class, followed by the generated class definition.
        class_name: the name chosen by the namer for the converted class.
        class_namespace: dict merging the namespaces returned by
          `function_to_graph` for each converted member.

    Raises:
      ValueError: if `c` has no function or method members.
      NotImplementedError: if a base class is neither matched by the `object`
        check nor whitelisted for graph conversion.
    """

    # TODO(mdan): Revisit this altogether. Not sure we still need it.
    def is_function_or_method(m):
        return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

    members = tf_inspect.getmembers(c, predicate=is_function_or_method)
    if not members:
        raise ValueError('Cannot convert %s: it has no member methods.' % (c,))

    converted_members = {}
    class_namespace = {}
    for _, m in members:
        # Only convert the members that are directly defined by the class.
        if inspect_utils.getdefiningclass(m, c) is not c:
            continue
        nodes, _, namespace = function_to_graph(
            m,
            program_ctx=program_ctx,
            arg_values={},
            arg_types={'self': (c.__name__, c)},
            do_rename=False)
        # `class_namespace` always starts out as a dict, so the per-member
        # namespaces are simply merged into it.
        class_namespace.update(namespace)
        converted_members[m] = nodes[0]
    namer = naming.Namer(class_namespace)
    class_name = namer.class_name(c.__name__)

    # Process any base classes: if the superclass is of a whitelisted type, an
    # absolute import line is generated.
    output_nodes = []
    renames = {}
    base_names = []
    for base in c.__bases__:
        # NOTE(review): the arguments look reversed (`base is object` may have
        # been intended), but the check is kept as-is to preserve behavior --
        # confirm with the owners before changing it.
        if isinstance(object, base):
            base_names.append('object')
            continue
        if not is_whitelisted_for_graph(base):
            raise NotImplementedError(
                'Conversion of classes that do not directly extend classes from'
                ' whitelisted modules is temporarily suspended. If this breaks'
                ' existing code please notify the AutoGraph team immediately.')
        alias = namer.new_symbol(base.__name__, ())
        output_nodes.append(
            gast.ImportFrom(
                module=base.__module__,
                names=[gast.alias(name=base.__name__, asname=alias)],
                level=0))
        base_names.append(alias)
        renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

    # Generate the definition of the converted class.
    bases = [gast.Name(n, gast.Load(), None) for n in base_names]
    class_def = gast.ClassDef(class_name,
                              bases=bases,
                              keywords=[],
                              body=list(converted_members.values()),
                              decorator_list=[])
    # Make a final pass to replace references to the class or its base classes.
    # Most commonly, this occurs when making super().__init__() calls.
    # TODO(mdan): Making direct references to superclass' superclass will fail.
    class_def = qual_names.resolve(class_def)
    renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
    class_def = ast_util.rename_symbols(class_def, renames)

    output_nodes.append(class_def)

    return output_nodes, class_name, class_namespace
示例#26
0
def build_docs(output_dir, code_url_prefix, search_hints=True):
  """Build api docs for tensorflow v2.

  After generating the docs, the output is sanity-checked: the number of
  generated files must reach a minimum, and a few spot-checked pages must
  contain the source path their "view source" link is expected to point at.

  Args:
    output_dir: A string path, where to put the files.
    code_url_prefix: prefix for "Defined in" links.
    search_hints: Bool. Include meta-data search hints at the top of each file.

  Raises:
    ValueError: if fewer files than the minimum were generated, or if any
      spot-checked page does not contain its expected source path.
  """
  # Fewest generated files accepted as a complete build. The check and the
  # error message below both use this value so they cannot drift apart.
  min_num_files = 2000

  # Use the custom page content for raw_ops.md instead of a generated one.
  doc_controls.set_custom_page_content(tf.raw_ops, generate_raw_ops_doc())

  # Hide raw_ops from search.
  for name, obj in tf_inspect.getmembers(tf.raw_ops):
    if not name.startswith("_"):
      doc_controls.hide_from_search(obj)

  _hide_layer_and_module_methods()

  # Each of the following modules may be absent from `tf`; a missing one is
  # skipped rather than treated as an error.
  try:
    doc_controls.do_not_generate_docs(tf.__operators__)
  except AttributeError:
    pass

  try:
    doc_controls.do_not_generate_docs(tf.tools)
  except AttributeError:
    pass

  try:
    doc_controls.do_not_generate_docs(tf.compat.v1.pywrap_tensorflow)
  except AttributeError:
    pass

  try:
    doc_controls.do_not_generate_docs(tf.pywrap_tensorflow)
  except AttributeError:
    pass

  try:
    doc_controls.do_not_generate_docs(tf.flags)
  except AttributeError:
    pass

  base_dirs, code_url_prefixes = base_dir.get_base_dirs_and_prefixes(
      code_url_prefix)
  doc_generator = generate_lib.DocGenerator(
      root_title="TensorFlow 2",
      py_modules=[("tf", tf)],
      base_dir=base_dirs,
      search_hints=search_hints,
      code_url_prefix=code_url_prefixes,
      site_path=FLAGS.site_path,
      visitor_cls=TfExportAwareVisitor,
      private_map=_PRIVATE_MAP)

  doc_generator.build(output_dir)

  out_path = pathlib.Path(output_dir)
  num_files = len(list(out_path.rglob("*")))
  if num_files < min_num_files:
    raise ValueError("The TensorFlow api should be more than {} files "
                     "(found {}).".format(min_num_files, num_files))

  # Generated page (relative to `output_dir`) -> source path that must appear
  # somewhere in that page's text.
  expected_path_contents = {
      "tf/summary/audio.md":
          "tensorboard/plugins/audio/summary_v2.py",
      "tf/estimator/DNNClassifier.md":
          "tensorflow_estimator/python/estimator/canned/dnn.py",
      "tf/nn/sigmoid_cross_entropy_with_logits.md":
          "python/ops/nn_impl.py",
      "tf/keras/Model.md":
          "tensorflow/python/keras/engine/training.py",
      "tf/compat/v1/gradients.md":
          "tensorflow/python/ops/gradients_impl.py",
  }

  all_passed = True
  error_msg_parts = [
      'Some "view source" links seem to be broken, please check:'
  ]

  # Check every page before raising so that one error lists all failures.
  for (rel_path, contents) in expected_path_contents.items():
    path = out_path / rel_path
    if contents not in path.read_text():
      all_passed = False
      error_msg_parts.append("  " + str(path))

  if not all_passed:
    raise ValueError("\n".join(error_msg_parts))
示例#27
0
from tensorflow_docs.api_generator import doc_generator_visitor
from tensorflow_docs.api_generator import generate_lib
from tensorflow_docs.api_generator import parser

import tensorboard
import tensorflow_estimator
from tensorflow.python.util import tf_export
from tensorflow.python.util import tf_inspect

# Route the parser through tensorflow's `tf_inspect`, which is aware of
# `tf_decorator`.
parser.tf_inspect = tf_inspect

# The doc generator recognizes `__all__` as the list of public symbols, but
# `tf.__all__` doesn't list important things like `keras`. Overwrite it with
# the name of every member of `tf`.
tf.__all__ = [member_name for member_name, _ in tf_inspect.getmembers(tf)]


FLAGS = flags.FLAGS

flags.DEFINE_string(
    name="code_url_prefix",
    default="/code/stable/tensorflow",
    help="A url to prepend to code paths when creating links to defining code")

flags.DEFINE_string(
    name="output_dir",
    default="/tmp/out",
    help="A directory, where the docs will be output to.")

flags.DEFINE_bool(
    name="search_hints",
    default=True,
    help="Include meta-data search hints at the top of each file.")
示例#28
0
 def testGetMembers(self):
   """Checks that `tf_inspect.getmembers` matches `inspect.getmembers`."""
   expected_members = inspect.getmembers(TestDecoratedClass)
   actual_members = tf_inspect.getmembers(TestDecoratedClass)
   self.assertEqual(expected_members, actual_members)
示例#29
0
def build_docs(output_dir, code_url_prefix, search_hints, gen_report):
    """Generate the TensorFlow v2 api docs and sanity-check the output.

    Args:
      output_dir: A string path of the directory the files are written to.
      code_url_prefix: Prefix for the "Defined in" links.
      search_hints: Bool. Include meta-data search hints at the top of each
        file.
      gen_report: Bool. Generates an API report containing the health of the
        docstrings of the public API. When true, the output checks that follow
        the build are skipped.

    Raises:
      ValueError: if a spot-checked page lacks its expected source path, if a
        page contains a rejected source path, or if fewer than 2000 files
        were generated.
    """
    # The custom page will be used for raw_ops.md instead of a generated one.
    doc_controls.set_custom_page_content(tf.raw_ops, generate_raw_ops_doc())

    # Keep the public members of raw_ops out of search.
    for member_name, member in tf_inspect.getmembers(tf.raw_ops):
        if member_name.startswith("_"):
            continue
        doc_controls.hide_from_search(member)

    base_classes = (
        tf.Module,
        tf.keras.layers.Layer,
        tf.keras.optimizers.Optimizer,
    )
    for base_class in base_classes:
        doc_controls.decorate_all_class_attributes(
            decorator=doc_controls.do_not_doc_in_subclasses,
            cls=base_class,
            skip=["__init__"])

    # Modules excluded from the docs. Each lookup is deferred through a lambda
    # so that a missing attribute raises inside the `try` below and that
    # module is simply skipped.
    undocumented_modules = (
        lambda: tf.__internal__,
        lambda: tf.keras.__internal__,
        lambda: tf.__operators__,
        lambda: tf.tools,
        lambda: tf.compat.v1.pywrap_tensorflow,
        lambda: tf.pywrap_tensorflow,
        lambda: tf.flags,
    )
    for get_module in undocumented_modules:
        try:
            doc_controls.do_not_generate_docs(get_module())
        except AttributeError:
            pass

    base_dirs, code_url_prefixes = base_dir.get_base_dirs_and_prefixes(
        code_url_prefix)
    generator_kwargs = dict(
        root_title="TensorFlow 2",
        py_modules=[("tf", tf)],
        base_dir=base_dirs,
        search_hints=search_hints,
        code_url_prefix=code_url_prefixes,
        site_path=FLAGS.site_path,
        visitor_cls=TfExportAwareVisitor,
        private_map=_PRIVATE_MAP,
        gen_report=gen_report,
        extra_docs=_EXTRA_DOCS)
    generate_lib.DocGenerator(**generator_kwargs).build(output_dir)

    if gen_report:
        return

    out_path = pathlib.Path(output_dir)

    # Generated page -> source path that must appear in the page text.
    expected_path_contents = {
        "tf/summary/audio.md":
            "tensorboard/plugins/audio/summary_v2.py",
        "tf/estimator/DNNClassifier.md":
            "tensorflow_estimator/python/estimator/canned/dnn.py",
        "tf/nn/sigmoid_cross_entropy_with_logits.md":
            "python/ops/nn_impl.py",
        "tf/keras/Model.md":
            "keras/engine/training.py",
        "tf/keras/preprocessing/image/random_brightness.md":
            "keras_preprocessing/image/affine_transformations.py",
    }
    # Every page is checked before raising, so one error lists all failures.
    broken_pages = [
        "  " + str(out_path / rel_path)
        for rel_path, expected in expected_path_contents.items()
        if expected not in (out_path / rel_path).read_text()
    ]
    if broken_pages:
        header = 'Some "view source" links seem to be broken, please check:'
        raise ValueError("\n".join([header] + broken_pages))

    # Generated page -> source path that must NOT appear in the page text.
    rejected_path_contents = {
        "tf/keras/optimizers.md": "keras/optimizers/__init__.py",
    }
    bad_pages = [
        "  " + str(out_path / rel_path)
        for rel_path, rejected in rejected_path_contents.items()
        if rejected in (out_path / rel_path).read_text()
    ]
    if bad_pages:
        header = 'Bad "view source" links in generated files, please check:'
        raise ValueError("\n".join([header] + bad_pages))

    num_files = sum(1 for _ in out_path.rglob("*"))
    if num_files >= 2000:
        return
    raise ValueError(
        "The TensorFlow api should be more than 2000 files(found {}).".format(
            num_files))
示例#30
0
def convert_class_to_ast(c, program_ctx):
  """Class flavor of `convert_entity_to_ast`.

  Each function/method defined directly on `c` is converted with
  `convert_func_to_ast`; the results become the body of a new `gast.ClassDef`.

  Args:
    c: The class to convert.
    program_ctx: Conversion context, forwarded to `convert_func_to_ast`.

  Returns:
    A tuple `(output_nodes, class_name, entity_info)`: the list of generated
    nodes (one import per whitelisted base class, then the class definition),
    the name given to the converted class, and an `EntityInfo` carrying the
    merged namespace and the future features of the converted methods.

  Raises:
    ValueError: if `c` has no function or method members, or if its methods
      were built with differing future features.
    NotImplementedError: if a base class is neither matched by the `object`
      check nor whitelisted for graph conversion.
  """

  # TODO(mdan): Revisit this altogether. Not sure we still need it.
  def is_function_or_method(m):
    return tf_inspect.isfunction(m) or tf_inspect.ismethod(m)

  members = tf_inspect.getmembers(c, predicate=is_function_or_method)
  if not members:
    raise ValueError('cannot convert %s: no member methods' % c)

  # TODO(mdan): Don't clobber namespaces for each method in one class namespace.
  # A single namespace is only sufficient when every method was defined in the
  # same module. Functions imported from several modules and spliced into the
  # class each come with their own globals and __future__ imports, which need
  # to stay separate.
  #
  # For example, both methods of C below could contain `global x` statements,
  # one referring to mod1.x and the other to mod2.x; a single namespace for C
  # would make them conflict.
  # from mod1 import f1
  # from mod2 import f2
  # class C(object):
  #   method1 = f1
  #   method2 = f2

  converted_members = {}
  class_namespace = {}
  future_features = None
  for _, member in members:
    # Members inherited from elsewhere are left unconverted.
    if inspect_utils.getdefiningclass(member, c) is not c:
      continue
    (member_node,), _, member_info = convert_func_to_ast(
        member, program_ctx=program_ctx, do_rename=False)
    class_namespace.update(member_info.namespace)
    converted_members[member] = member_node

    # TODO(mdan): Similarly check the globals.
    if future_features is None:
      future_features = member_info.future_features
      continue
    if frozenset(future_features) != frozenset(member_info.future_features):
      # Note: we can support this case if ever needed.
      raise ValueError(
          'cannot convert {}: if has methods built with mismatched future'
          ' features: {} and {}'.format(c, future_features,
                                        member_info.future_features))

  namer = naming.Namer(class_namespace)
  class_name = namer.class_name(c.__name__)

  # Handle the base classes: a whitelisted superclass gets an absolute import
  # line under a fresh alias; anything else is rejected.
  output_nodes = []
  renames = {}
  base_names = []
  for base in c.__bases__:
    if isinstance(object, base):
      base_names.append('object')
      continue
    if not is_whitelisted_for_graph(base):
      raise NotImplementedError(
          'Conversion of classes that do not directly extend classes from'
          ' whitelisted modules is temporarily suspended. If this breaks'
          ' existing code please notify the AutoGraph team immediately.')
    alias = namer.new_symbol(base.__name__, ())
    import_node = gast.ImportFrom(
        module=base.__module__,
        names=[gast.alias(name=base.__name__, asname=alias)],
        level=0)
    output_nodes.append(import_node)
    base_names.append(alias)
    renames[qual_names.QN(base.__name__)] = qual_names.QN(alias)

  # Build the definition of the converted class.
  class_def = gast.ClassDef(
      class_name,
      bases=[gast.Name(base_name, gast.Load(), None)
             for base_name in base_names],
      keywords=[],
      body=list(converted_members.values()),
      decorator_list=[])
  # Final pass: replace references to the class or its base classes. This
  # most commonly matters for super().__init__() calls.
  # TODO(mdan): Making direct references to superclass' superclass will fail.
  class_def = qual_names.resolve(class_def)
  renames[qual_names.QN(c.__name__)] = qual_names.QN(class_name)
  class_def = ast_util.rename_symbols(class_def, renames)

  output_nodes.append(class_def)

  # TODO(mdan): Find a way better than forging this object.
  class_info = transformer.EntityInfo(
      source_code=None,
      source_file=None,
      future_features=future_features,
      namespace=class_namespace)

  return output_nodes, class_name, class_info
示例#31
0
 def testGetMembers(self):
      """Compares `tf_inspect.getmembers` against `inspect.getmembers`."""
      reference_members = inspect.getmembers(TestDecoratedClass)
      self.assertEqual(reference_members,
                       tf_inspect.getmembers(TestDecoratedClass))