示例#1
0
  def __init__(self, cls_ref, method_name, **kwargs):
    """Creates a layer that calls the method `method_name` of `cls_ref`.

    Args:
      cls_ref: The class whose method this layer calls.
      method_name: Name of the method on `cls_ref`.
      **kwargs: Keyword arguments forwarded to the base layer constructor.
        `autocast` is always forced to False. If `name` is not given, a unique
        name is generated from the class's exported symbol, falling back to
        the class's `__name__` when the class has no exported symbol.
    """
    self.cls_ref = cls_ref
    self.method_name = method_name
    # Prefer the `tf.*` export and fall back to the `keras.*` export. Both
    # lookups may return None for classes that are not exported at all.
    self.cls_symbol = (
        get_canonical_name_for_symbol(
            self.cls_ref, add_prefix_to_v1_names=True) or
        get_canonical_name_for_symbol(
            self.cls_ref, api_name='keras', add_prefix_to_v1_names=True))
    if 'name' not in kwargs:
      # An unexported class has no symbol; concatenating 'tf.' with None
      # would raise a TypeError, so use the Python class name instead.
      if self.cls_symbol:
        base_name = 'tf.' + self.cls_symbol
      else:
        base_name = self.cls_ref.__name__
      kwargs['name'] = backend.unique_object_name(
          base_name + '.' + self.method_name,
          zero_based=True,
          avoid_observed_names=True)
    kwargs['autocast'] = False

    # Do not individually trace op layers in the SavedModel.
    self._must_restore_from_config = True

    super(ClassMethod, self).__init__(**kwargs)

    # Preserve all argument data structures when saving/loading a config
    # (e.g., don't unnest lists that contain one element)
    self._preserve_input_structure_in_config = True

    self._expects_training_arg = False
    self._expects_mask_arg = False
示例#2
0
def register_dispatchers():
    """Constructs & registers OpDispatchers for ragged ops.

    Every op is first checked to be an exported `tf` symbol. Only after all
    checks pass are dispatchers registered for the unary, unary-list, binary
    and explicitly listed ragged ops.

    Returns:
      A markdown string: a "### Additional ops that support `RaggedTensor`"
      heading followed by one "* `tf.<name>`" bullet per op.

    Raises:
      AssertionError: If any op is not an exported `tf` symbol.
    """

    op_list = (_UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
               _BINARY_ELEMENTWISE_OPS + [x[0] for x in _RAGGED_DISPATCH_OPS])
    for op in op_list:
        _, undecorated_op = tf_decorator.unwrap(op)
        if not hasattr(undecorated_op,
                       tf_export.API_ATTRS['tensorflow'].names):
            # Interpolate the op so the message names the offending symbol
            # instead of showing a literal '%s'.
            raise AssertionError('Expected %s to be an exported symbol '
                                 '(while adding a RaggedTensor dispatcher)' %
                                 (op,))

    for op in _UNARY_ELEMENTWISE_OPS:
        UnaryRaggedElementwiseDispatcher(op).register(op)

    for op in _UNARY_LIST_ELEMENTWISE_OPS:
        UnaryRaggedElementwiseDispatcher(op, True).register(op)

    for op in _BINARY_ELEMENTWISE_OPS:
        BinaryRaggedElementwiseDispatcher(op).register(op)

    for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS:
        RaggedDispatcher(original_op, ragged_op, args).register(original_op)

    docstring = ('\n\n### Additional ops that support `RaggedTensor`\n\n' +
                 '\n'.join([
                     '* `tf.%s`' % tf_export.get_canonical_name_for_symbol(op)
                     for op in op_list
                 ]))

    return docstring
示例#3
0
def serialize_keras_object(obj):
    """Serialize a Keras object into its config dictionary.

    This is the inverse of `deserialize_keras_object()`; see that function for
    a description of the config format.

    Args:
      obj: The Keras object to serialize.

    Returns:
      A dict with the keys "module", "class_name", "config" and
      "registered_name" that `deserialize_keras_object()` can turn back into
      an object.
    """
    cls = obj.__class__
    # Classes exported under `keras.*` are identified by that public name
    # (e.g. "keras.optimizers.Adam") and carry no module. Anything else is
    # described by its defining module and bare class name; for a function
    # that is ("builtins", "function"), for a string ("builtins", "str").
    exported_name = tf_export.get_canonical_name_for_symbol(cls,
                                                            api_name="keras")
    if exported_name is not None:
        module, class_name = None, exported_name
    else:
        module, class_name = cls.__module__, cls.__name__
    return {
        "module": module,
        "class_name": class_name,
        "config": _get_object_config(obj),
        "registered_name": _get_object_registered_name(obj),
    }
示例#4
0
def _ragged_op_signature(op, ragged_args, ragged_varargs=False):
    """Build a markdown signature line for `op` with ragged args in bold.

    Args:
      op: The op whose signature is rendered.
      ragged_args: Indices into the op's positional argument names of the
        arguments to render in bold.
      ragged_varargs: If true, the `*varargs` parameter is rendered in bold.

    Returns:
      A string of the form "* `tf.<name>`(<arg>, <arg>=`<default>`, ...)".
    """
    op_name = tf_export.get_canonical_name_for_symbol(op)
    spec = tf_inspect.getfullargspec(op)
    params = spec.args

    for index in ragged_args:
        params[index] = '**' + params[index] + '**'

    defaults = spec.defaults
    if defaults is not None:
        # Defaults belong to the trailing positional arguments.
        for offset in range(1, len(defaults) + 1):
            params[-offset] += '=`{!r}`'.format(defaults[-offset])

    if spec.varargs:
        if ragged_varargs:
            star_param = '***' + spec.varargs + '**'
        else:
            star_param = '*' + spec.varargs
        params.append(star_param)
    if spec.varkw:
        params.append('**' + spec.varkw)

    joined = ', '.join(params)
    return '* `tf.{}`({})'.format(op_name, joined)
示例#5
0
def _add_elementwise_ops_to_this_module(specs, verbose=False):
    """Adds ragged versions of the given ops to this module.

    Each wrapped op is bound as a module-level global and its name appended
    to `_symbols_to_export`.

    Args:
      specs: A list of tuples containing the arguments for
        `make_elementwise_op`; the first element of each tuple is the op to
        wrap.
      verbose: If true, then display each op that gets added.
    """
    module_namespace = globals()
    for spec in specs:
        base_op = spec[0]
        wrapped = make_elementwise_op(*spec)
        canonical_name = tf_export.get_canonical_name_for_symbol(base_op)
        # An undotted canonical name is reused as-is; a dotted one falls back
        # to the op's Python `__name__`.
        op_name = (base_op.__name__ if '.' in canonical_name
                   else canonical_name)

        # Temporary hack (will be removed once dispatch is added for
        # RaggedTensors): `neg` is exposed as `negative`.
        if op_name == 'neg':
            op_name = 'negative'

        if verbose:
            print(
                'Adding ragged_elementwise_op: tf.ragged.%s (based on tf.%s)' %
                (op_name, canonical_name))
        module_namespace[op_name] = wrapped
        _symbols_to_export.append(op_name)
示例#6
0
    def handle(self, op, args, kwargs):
        """Traces `op` when any of its arguments is a TensorTracer.

        Args:
          op: The op being dispatched.
          args: Positional arguments of the call.
          kwargs: Keyword arguments of the call.

        Returns:
          `self.NOT_SUPPORTED` if no positional or keyword argument is a
          TensorTracer; otherwise a `TensorTracer` built from the op's
          canonical name and the call arguments.
        """
        has_tracer = (any(map(self.is_tensor_tracer_arg, args)) or
                      any(map(self.is_tensor_tracer_arg, kwargs.values())))
        if has_tracer:
            return TensorTracer(get_canonical_name_for_symbol(op), args,
                                kwargs)
        return self.NOT_SUPPORTED
示例#7
0
  def __init__(self, function, **kwargs):
    """Wraps the TF API callable `function` as a Keras layer.

    Args:
      function: The TF API function this layer calls.
      **kwargs: Keyword arguments forwarded to the base layer constructor.
        When `name` is missing, a unique one is derived from the function's
        exported symbol, or from its `__name__` if it has none. `autocast` is
        always forced to False.
    """
    self.function = function
    # Look up the `tf.*` export first, then the `keras.*` export.
    self.symbol = (
        get_canonical_name_for_symbol(
            self.function, add_prefix_to_v1_names=True) or
        get_canonical_name_for_symbol(
            self.function, api_name='keras', add_prefix_to_v1_names=True))
    if 'name' not in kwargs:
      # Auto-generated layers skip names that have already been observed.
      # Users can pass `name` to standard layers to dodge a collision, but
      # they cannot easily do so for these generated layers.
      base_name = (
          'tf.' + self.symbol if self.symbol else self.function.__name__)
      kwargs['name'] = backend.unique_object_name(
          base_name, zero_based=True, avoid_observed_names=True)
    kwargs['autocast'] = False

    # The layer's `call` forwards to `self._call_wrapper`, decorated so it
    # presents `function` as its target.
    def _call_wrapper(*args, **kwargs):
      return self._call_wrapper(*args, **kwargs)

    self.call = tf.__internal__.decorator.make_decorator(
        function, _call_wrapper)

    # Op layers are not traced individually in the SavedModel.
    self._must_restore_from_config = True

    super(TFOpLambda, self).__init__(**kwargs)

    # Keep argument structures intact in saved/loaded configs
    # (e.g. single-element lists are not unnested).
    self._preserve_input_structure_in_config = True

    # Tracks whether the warning was already shown; warning on every
    # invocation would be irksome in Eager mode.
    self._already_warned = False

    self._expects_training_arg = False
    self._expects_mask_arg = False
示例#8
0
  def _score_name(self, name):
    """Scores `name`, favoring the object's canonical `tf.` export.

    Returns:
      A tuple starting with -1 if `name` is "tf." plus the canonical name of
      the object indexed under `name` (1 otherwise), followed by the base
      class's score for `name`.
    """
    canonical = tf_export.get_canonical_name_for_symbol(self._index[name])
    is_canonical = canonical is not None and name == "tf." + canonical
    base_scores = super(TfExportAwareDocGeneratorVisitor,
                        self)._score_name(name)
    return (-1 if is_canonical else 1,) + base_scores
示例#9
0
  def _score_name(self, name):
    """Scores `name` so the canonical `tf.` alias compares lowest.

    Returns:
      `(canonical_score,)` concatenated with the base class's score tuple,
      where `canonical_score` is -1 when `name` equals "tf." plus the
      object's canonical name and 1 otherwise.
    """
    py_object = self._index[name]
    canonical = tf_export.get_canonical_name_for_symbol(py_object)
    if canonical is None or name != "tf." + canonical:
      canonical_score = 1
    else:
      canonical_score = -1
    parent = super(TfExportAwareDocGeneratorVisitor, self)
    return (canonical_score,) + parent._score_name(name)
示例#10
0
def ragged_op_list():
    """Returns a markdown string listing ops that have dispatchers registered.

    The string is a "### Additional ops that support `RaggedTensor`" heading
    followed by one "* `tf.<canonical name>`" bullet per op.
    """
    dispatched_ops = [entry[0] for entry in _RAGGED_DISPATCH_OPS]
    all_ops = (_UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
               _BINARY_ELEMENTWISE_OPS + dispatched_ops)
    header = '\n\n### Additional ops that support `RaggedTensor`\n\n'
    bullets = ['* `tf.%s`' % tf_export.get_canonical_name_for_symbol(op)
               for op in all_ops]
    return header + '\n'.join(bullets)
示例#11
0
def _update_docstring_with_api_list(target, api_list):
    """Replaces `<<API_LIST>>` in target.__doc__ with the given list of APIs."""
    lines = []
    for func in api_list:
        name = tf_export_lib.get_canonical_name_for_symbol(
            func, add_prefix_to_v1_names=True)
        if name is not None:
            signature = tf_inspect.signature(func)
            lines.append(f"  * `tf.{name}{signature}`")
    lines.sort()
    target.__doc__ = target.__doc__.replace("  <<API_LIST>>", "\n".join(lines))
 def testExportMultipleFunctions(self):
     """Two exports register their names independently of each other."""
     decorator_ab = tf_export.tf_export('nameA', 'nameB')
     decorator_cd = tf_export.tf_export('nameC', 'nameD')
     exported_ab = decorator_ab(_test_function)
     exported_cd = decorator_cd(_test_function2)

     # Exporting hands back the same function.
     for exported, original in ((exported_ab, _test_function),
                                (exported_cd, _test_function2)):
         self.assertEqual(exported, original)
     # Each function carries exactly the names it was exported under.
     for names, exported in ((('nameA', 'nameB'), exported_ab),
                             (('nameC', 'nameD'), exported_cd)):
         self.assertEqual(names, exported._tf_api_names)
     # A secondary export name resolves back to the function.
     for alias, exported in (('nameB', exported_ab), ('nameD', exported_cd)):
         self.assertEqual(tf_export.get_symbol_from_name(alias), exported)
     # So does the canonical name.
     for exported in (exported_ab, exported_cd):
         canonical = tf_export.get_canonical_name_for_symbol(exported)
         self.assertEqual(tf_export.get_symbol_from_name(canonical), exported)
示例#13
0
  def _score_name(self, name):
    """Scores `name`, favoring the object's canonical export.

    The canonical name is looked up in the TensorFlow API and then the
    Estimator API; the first non-None result is used.

    Returns:
      `(canonical_score,)` plus the base class's score, where
      `canonical_score` is -1 if `name` equals "tf." plus that canonical name
      and 1 otherwise.
    """
    api_names = (tf_export.TENSORFLOW_API_NAME, tf_export.ESTIMATOR_API_NAME)
    # Lazily evaluated, so lookups stop at the first hit.
    lookups = (
        tf_export.get_canonical_name_for_symbol(
            self._index[name], api_name=api_name)
        for api_name in api_names)
    canonical = next((found for found in lookups if found is not None), None)

    is_canonical = canonical is not None and name == "tf." + canonical
    return (-1 if is_canonical else 1,) + super()._score_name(name)
示例#14
0
  def _score_name(self, path: doc_generator_visitor.ApiPath) -> TfNameScore:
    """Scores the API `path`, favoring the object's canonical export.

    The canonical name is looked up in the TensorFlow, Keras and Estimator
    APIs, in that order; the first non-None result wins.

    Returns:
      A `TfNameScore` of (-1 if the dotted `path` equals "tf." plus that
      canonical name, else 1) and the base class's score for `path`.
    """
    name = ".".join(path)
    canonical = None
    for api_name in (tf_export.TENSORFLOW_API_NAME,
                     tf_export.KERAS_API_NAME,
                     tf_export.ESTIMATOR_API_NAME):
      canonical = tf_export.get_canonical_name_for_symbol(
          self._index[name], api_name=api_name)
      if canonical is not None:
        break

    if canonical is None or name != "tf." + canonical:
      canonical_score = 1
    else:
      canonical_score = -1
    return self.TfNameScore(canonical_score, super()._score_name(path))
示例#15
0
 def testExportSingleFunction(self):
   """Exporting one function under two names makes both names resolve to it."""
   export_decorator = tf_export.tf_export('nameA', 'nameB')
   decorated_function = export_decorator(_test_function)
   # Use `assertEqual`: the `assertEquals` alias was deprecated and has been
   # removed from unittest in Python 3.12.
   self.assertEqual(decorated_function, _test_function)
   self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
   self.assertEqual(['nameA', 'nameB'],
                    tf_export.get_v1_names(decorated_function))
   self.assertEqual(['nameA', 'nameB'],
                    tf_export.get_v2_names(decorated_function))
   self.assertEqual(tf_export.get_symbol_from_name('nameA'),
                    decorated_function)
   self.assertEqual(tf_export.get_symbol_from_name('nameB'),
                    decorated_function)
   self.assertEqual(
       tf_export.get_symbol_from_name(
           tf_export.get_canonical_name_for_symbol(decorated_function)),
       decorated_function)
示例#16
0
 def testExportSingleFunctionV1Only(self):
   """A v1-only export resolves under compat.v1 and has no v2 names."""
   v1_names = ('nameA', 'nameB')
   decorated_function = tf_export.tf_export(v1=list(v1_names))(_test_function)
   self.assertEqual(decorated_function, _test_function)
   self.assertAllEqual(v1_names, decorated_function._tf_api_names_v1)
   self.assertAllEqual(list(v1_names),
                       tf_export.get_v1_names(decorated_function))
   self.assertEqual([], tf_export.get_v2_names(decorated_function))
   # Every v1 name is reachable through the compat.v1 namespace.
   for short_name in v1_names:
     self.assertEqual(
         tf_export.get_symbol_from_name('compat.v1.' + short_name),
         decorated_function)
   canonical = tf_export.get_canonical_name_for_symbol(
       decorated_function, add_prefix_to_v1_names=True)
   self.assertEqual(tf_export.get_symbol_from_name(canonical),
                    decorated_function)
def _add_elementwise_ops_to_this_module(specs, verbose=False):
  """Adds ragged versions of the given ops to this module.

  Each wrapped op is bound as a module-level global and its name appended to
  `_symbols_to_export`.

  Args:
    specs: A list of tuples containing the arguments for `make_elementwise_op`;
      the first element of each tuple is the op to wrap.
    verbose: If true, then display each op that gets added.
  """
  namespace = globals()
  for spec in specs:
    op = spec[0]
    ragged_version = make_elementwise_op(*spec)
    canonical_name = tf_export.get_canonical_name_for_symbol(op)
    # An undotted canonical name is reused as-is; a dotted one falls back to
    # the op's Python `__name__`.
    op_name = op.__name__ if '.' in canonical_name else canonical_name
    if verbose:
      print('Adding ragged_elementwise_op: tf.ragged.%s (based on tf.%s)' %
            (op_name, canonical_name))
    namespace[op_name] = ragged_version
    _symbols_to_export.append(op_name)
示例#18
0
def _ragged_op_signature(op, ragged_args):
  """Returns a signature for the given op, marking ragged args in bold.

  Args:
    op: The op whose signature is rendered.
    ragged_args: Indices into the op's positional argument names of the
      arguments to render in bold.

  Returns:
    A markdown bullet "* `tf.<name>`(<arg>, <arg>=`<default>`, ...)".
  """
  op_name = tf_export.get_canonical_name_for_symbol(op)
  argspec = tf_inspect.getfullargspec(op)
  arg_names = argspec.args

  # Mark ragged arguments in bold.
  for pos in ragged_args:
    arg_names[pos] = '**' + arg_names[pos] + '**'

  # Add argument defaults. `argspec.defaults` is None (not an empty tuple)
  # when no argument has a default, so it must be checked before len().
  # Defaults belong to the trailing positional arguments, hence the negative
  # indices.
  if argspec.defaults is not None:
    for pos in range(-1, -len(argspec.defaults) - 1, -1):
      arg_names[pos] += '=`{!r}`'.format(argspec.defaults[pos])

  # Add varargs and keyword args
  if argspec.varargs:
    arg_names.append('*' + argspec.varargs)
  if argspec.varkw:
    arg_names.append('**' + argspec.varkw)

  return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names))
示例#19
0
def add_imports_for_symbol(module_code_builder,
                           symbol,
                           source_module_name,
                           source_name,
                           api_name,
                           api_version,
                           output_module_prefix='',
                           decorators=None):
    """Add imports for the given symbol to `module_code_builder`.

    Two independent cases are handled:
      * If `source_name` is the API constants attribute, `symbol` is iterated
        as `(exports, name)` pairs and one import is added per export.
      * If `symbol` has the API names attribute in its own `__dict__`, one
        import is added per exported name. Names that are not also V2
        exports, are outside the compat module, and lack a deprecation
        decorator are additionally registered as deprecated endpoints that
        point at the symbol's canonical name.

    Args:
      module_code_builder: `_ModuleInitCodeBuilder` instance.
      symbol: A symbol.
      source_module_name: Module that we can import the symbol from.
      source_name: Name we can import the symbol with.
      api_name: API name. Currently, must be either `tensorflow` or `estimator`.
      api_version: API version.
      output_module_prefix: Prefix to prepend to destination module.
      decorators: Tuple of symbol's decorators.
    """
    # V1 generation reads the V1 attribute names; otherwise the V2 ones.
    names_attr_v2 = API_ATTRS[api_name].names
    constants_attr_v2 = API_ATTRS[api_name].constants
    if api_version == 1:
        names_attr = API_ATTRS_V1[api_name].names
        constants_attr = API_ATTRS_V1[api_name].constants
    else:
        names_attr = names_attr_v2
        constants_attr = constants_attr_v2

    # If symbol is _tf_api_constants attribute, then add the constants.
    if source_name == constants_attr:
        for exports, name in symbol:
            for export in exports:
                dest_module, dest_name = _get_name_and_module(export)
                dest_module = _join_modules(output_module_prefix, dest_module)
                # NOTE(review): -1 is passed where the symbol branch below
                # passes id(symbol); presumably a sentinel id for constants —
                # confirm against _ModuleInitCodeBuilder.add_import.
                module_code_builder.add_import(-1, dest_module,
                                               source_module_name, name,
                                               dest_name)

    # If symbol has _tf_api_names attribute, then add import for it.
    if (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__):
        # Get a list of all V2 names if we generate V1 API to check for
        # deprecations.
        # NOTE(review): when api_version != 1, exports_v2 stays empty, so the
        # deprecation check below holds for every non-compat, undecorated
        # export; presumably the builder ignores these for V2 — confirm.
        exports_v2 = []
        if api_version == 1 and hasattr(symbol, names_attr_v2):
            exports_v2 = getattr(symbol, names_attr_v2)
        # Looked up on first need; reused afterwards once it is non-empty.
        canonical_endpoint = None

        # Generate import statements for symbols.
        for export in getattr(symbol, names_attr):  # pylint: disable=protected-access
            dest_module, dest_name = _get_name_and_module(export)
            dest_module = _join_modules(output_module_prefix, dest_module)
            module_code_builder.add_import(id(symbol), dest_module,
                                           source_module_name, source_name,
                                           dest_name)
            # Export is deprecated if it is not in 2.0.
            if (export not in exports_v2
                    and not dest_module.startswith(_COMPAT_MODULE_PREFIX)
                    and not has_deprecation_decorator(symbol, decorators)):
                if not canonical_endpoint:
                    canonical_endpoint = tf_export.get_canonical_name_for_symbol(
                        symbol, api_name, True)
                module_code_builder.add_deprecated_endpoint(
                    dest_module, dest_name, canonical_endpoint)
示例#20
0
  def _maybe_find_duplicates(self):
    """Compute data structures containing information about duplicates.

    Find duplicates in `index` and decide on one to be the "master" name.

    Computes a reverse_index mapping each object id to its master name.

    Also computes a map `duplicate_of` from aliases to their master name (the
    master name itself has no entry in this map), and a map `duplicates` from
    master names to a lexicographically sorted list of all aliases for that name
    (incl. the master name).

    The master name is "tf." plus the object's canonical export name when
    that name is one of the aliases; otherwise it is the alias that minimizes
    `self._score_name`.

    All these are computed and set as fields if they haven't already.
    """
    if self._reverse_index is not None:
      return

    # Maps the id of a symbol to its fully qualified name. For symbols that have
    # several aliases, this map contains the first one found.
    # We use id(py_object) to get a hashable value for py_object. Note all
    # objects in _index are in memory at the same time so this is safe.
    reverse_index = {}

    # Make a preliminary duplicates map. For all sets of duplicate names, it
    # maps the first name found to a list of all duplicate names.
    raw_duplicates = {}
    for full_name, py_object in six.iteritems(self._index):
      # We cannot use the duplicate mechanism for some constants, since e.g.,
      # id(c1) == id(c2) with c1=1, c2=1. This is unproblematic since constants
      # have no usable docstring and won't be documented automatically.
      # The empty tuple is such a shared constant as well. It is matched by
      # exact type and emptiness instead of `py_object is not ()`, because an
      # identity comparison with a literal is a SyntaxWarning on Python 3.8+.
      if (py_object is not None and
          not isinstance(py_object, six.integer_types + six.string_types +
                         (six.binary_type, six.text_type, float, complex, bool))
          and not (type(py_object) is tuple and not py_object)):
        object_id = id(py_object)
        if object_id in reverse_index:
          master_name = reverse_index[object_id]
          if master_name in raw_duplicates:
            raw_duplicates[master_name].append(full_name)
          else:
            raw_duplicates[master_name] = [master_name, full_name]
        else:
          reverse_index[object_id] = full_name
    # Decide on master names, rewire duplicates and make a duplicate_of map
    # mapping all non-master duplicates to the master name. The master symbol
    # does not have an entry in this map.
    duplicate_of = {}
    # Duplicates maps the main symbols to the set of all duplicates of that
    # symbol (incl. itself).
    duplicates = {}
    for names in raw_duplicates.values():
      # Every group is created with two names, so names[0] always exists.
      names = sorted(names)
      canonical = tf_export.get_canonical_name_for_symbol(self._index[names[0]])
      master_name = 'tf.%s' % canonical if canonical else None
      # Only trust the canonical name if it is one of the aliases; otherwise
      # `self._index[master_name]` below would raise a KeyError.
      if master_name not in names:
        # Choose the master name with a lexical sort on the tuples returned by
        # _score_name.
        master_name = min(names, key=self._score_name)

      duplicates[master_name] = names
      for name in names:
        if name != master_name:
          duplicate_of[name] = master_name

      # Set the reverse index to the canonical name.
      reverse_index[id(self._index[master_name])] = master_name

    self._duplicate_of = duplicate_of
    self._duplicates = duplicates
    self._reverse_index = reverse_index
def make_elementwise_op(op, *elementwise_args):
    """Returns a ragged-tensor version of the elementwise operation `op`.

    The returned operation will:

    1. Broadcast the elementwise arguments to have a compatible shape.
       An exception is raised if the tensors are not broadcast-compatible.
    2. Call `op`, substituting the dense values of the broadcasted tensor for
       each elementwise argument.
    3. Return a potentially ragged tensor constructed from the output of `op`
       and the broadcasted tensors' nested row splits.

    For example, you can construct a ragged-tensor version of the standard
    operation `tf.add` by calling `make_elementwise_op(tf.add, 'x', 'y')`.

    Args:
      op: The operation to wrap.
      *elementwise_args: The names of arguments to `op` that are treated as
        elementwise.  Arguments that take a list of tensors should have their
        names wrapped in square brackets (e.g. "[inputs]").

    Returns:
      A wrapper around `op` (built with `tf_decorator.make_decorator`) whose
      docstring is generated from `_ELEMENTWISE_DOCSTRING`.

    Raises:
      ValueError: If any name specified in `elementwise_args` is not the name
        of an argument to `op`.
    """
    # NOTE(review): entries are looked up below by position (int) and by
    # name (str), so this presumably maps both to the same arg info — confirm
    # in _get_arg_infos.
    elementwise_arg_infos = _get_arg_infos(op, elementwise_args)

    def ragged_op(*args, **kwargs):
        """Ragged version of `op`."""
        args = list(args)

        # Collect all of the elementwise arguments, and put them in a single
        # dict whose values are the (potentially ragged) tensors that need to
        # be broadcast to a common shape.  The keys of this dict are tuples
        # (argkey, index), where argkey is an int for positional args or a
        # string for keyword args; and index is None for non-list args and the
        # index of the tensor for list args.
        # This local dict shadows the outer `elementwise_args` tuple of names
        # only inside ragged_op.
        elementwise_args = {}
        for (name, position, is_list) in elementwise_arg_infos.values():
            if position < len(args):
                if is_list:
                    args[position] = list(args[position])
                    for (index, arg) in enumerate(args[position]):
                        elementwise_args[position, index] = arg
                else:
                    elementwise_args[position, None] = args[position]
            elif name in kwargs:
                if is_list:
                    kwargs[name] = list(kwargs[name])
                    for (i, arg) in enumerate(kwargs[name]):
                        elementwise_args[name, i] = arg
                else:
                    elementwise_args[name, None] = kwargs[name]

        with ops.name_scope(None, op.__name__, elementwise_args.values()):
            # Convert all inputs to tensors or ragged tensors.
            for ((key, index), tensor) in elementwise_args.items():
                argname = elementwise_arg_infos[key].name
                converted = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
                    tensor, name=argname)
                elementwise_args[key, index] = converted

            # Broadcast tensors to have compatible shapes.
            broadcast_args, result_splits, broadcast_check_ops = \
                _broadcast_elementwise_args(elementwise_args)

            # Replace ragged tensor arguments with their inner (dense) values,
            # writing back into positional or keyword slots as keyed above.
            for ((key, index), tensor) in broadcast_args.items():
                if ragged_tensor.is_ragged(tensor):
                    if isinstance(key, int) and index is None:
                        args[key] = tensor.inner_values
                    elif isinstance(key, int) and index is not None:
                        args[key][index] = tensor.inner_values
                    elif isinstance(key, str) and index is None:
                        kwargs[key] = tensor.inner_values
                    else:
                        assert isinstance(key, str) and index is not None
                        kwargs[key][index] = tensor.inner_values

            # Call the elementwise op on the broadcasted dense values; the
            # control dependencies make the broadcast checks run first.
            with ops.control_dependencies(broadcast_check_ops):
                result_values = op(*args, **kwargs)

            # Restore any ragged dimensions that we stripped off, and return the
            # result.
            return ragged_factory_ops.from_nested_row_splits(
                result_values, result_splits)

    # Construct the docstring. Here `elementwise_args` is the outer tuple of
    # argument names; list-style names like "[inputs]" lose their brackets.
    op_name = tf_export.get_canonical_name_for_symbol(op)
    assert op_name is not None, op
    argnames = ', '.join('`%s`' % s.strip('[]') for s in elementwise_args)
    docstring = _ELEMENTWISE_DOCSTRING % dict(op_name=op_name,
                                              argnames=argnames)

    # Update name, docstring, signature, etc., for the wrapper, and return it.
    return tf_decorator.make_decorator(op, ragged_op, decorator_doc=docstring)
def make_elementwise_op(op, *elementwise_args):
  """Returns a ragged-tensor version of the elementwise operation `op`.

  The returned operation will:

  1. Broadcast the elementwise arguments to have a compatible shape.
     An exception is raised if the tensors are not broadcast-compatible.
  2. Call `op`, substituting the dense values of the broadcasted tensor for
     each elementwise argument.
  3. Return a potentially ragged tensor constructed from the output of `op`
     and the broadcasted tensors' nested row splits.

  For example, you can construct a ragged-tensor version of the standard
  operation `tf.add` by calling `make_elementwise_op(tf.add, 'x', 'y')`.

  Args:
    op: The operation to wrap.
    *elementwise_args: The names of arguments to `op` that are treated as
      elementwise.  Arguments that take a list of tensors should have their
      names wrapped in square brackets (e.g. "[inputs]").

  Returns:
    A wrapper around `op` (built with `tf_decorator.make_decorator`) whose
    docstring is generated from `_ELEMENTWISE_DOCSTRING`.

  Raises:
    ValueError: If any name specified in `elementwise_args` is not the name
      of an argument to `op`.
  """
  # NOTE(review): entries are looked up below by position (int) and by name
  # (str), so this presumably maps both to the same arg info — confirm in
  # _get_arg_infos.
  elementwise_arg_infos = _get_arg_infos(op, elementwise_args)

  def ragged_op(*args, **kwargs):
    """Ragged version of `op`."""
    args = list(args)

    # Collect all of the elementwise arguments, and put them in a single
    # dict whose values are the (potentially ragged) tensors that need to
    # be broadcast to a common shape.  The keys of this dict are tuples
    # (argkey, index), where argkey is an int for positional args or a string
    # for keyword args; and index is None for non-list args and the index of the
    # tensor for list args.
    # This local dict shadows the outer `elementwise_args` tuple of names only
    # inside ragged_op.
    elementwise_args = {}
    for (name, position, is_list) in elementwise_arg_infos.values():
      if position < len(args):
        if is_list:
          args[position] = list(args[position])
          for (index, arg) in enumerate(args[position]):
            elementwise_args[position, index] = arg
        else:
          elementwise_args[position, None] = args[position]
      elif name in kwargs:
        if is_list:
          kwargs[name] = list(kwargs[name])
          for (i, arg) in enumerate(kwargs[name]):
            elementwise_args[name, i] = arg
        else:
          elementwise_args[name, None] = kwargs[name]

    with ops.name_scope(None, op.__name__, elementwise_args.values()):
      # Convert all inputs to tensors or ragged tensors.
      for ((key, index), tensor) in elementwise_args.items():
        argname = elementwise_arg_infos[key].name
        converted = ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
            tensor, name=argname)
        elementwise_args[key, index] = converted

      # Broadcast tensors to have compatible shapes.
      broadcast_args, result_splits, broadcast_check_ops = \
          _broadcast_elementwise_args(elementwise_args)

      # Replace ragged tensor arguments with their inner (dense) values,
      # writing back into positional or keyword slots as keyed above.
      for ((key, index), tensor) in broadcast_args.items():
        if ragged_tensor.is_ragged(tensor):
          if isinstance(key, int) and index is None:
            args[key] = tensor.inner_values
          elif isinstance(key, int) and index is not None:
            args[key][index] = tensor.inner_values
          elif isinstance(key, str) and index is None:
            kwargs[key] = tensor.inner_values
          else:
            assert isinstance(key, str) and index is not None
            kwargs[key][index] = tensor.inner_values

      # Call the elementwise op on the broadcasted dense values; the control
      # dependencies make the broadcast checks run first.
      with ops.control_dependencies(broadcast_check_ops):
        result_values = op(*args, **kwargs)

      # Restore any ragged dimensions that we stripped off, and return the
      # result.
      return ragged_factory_ops.from_nested_row_splits(result_values,
                                                       result_splits)

  # Construct the docstring. Here `elementwise_args` is the outer tuple of
  # argument names; list-style names like "[inputs]" lose their brackets.
  op_name = tf_export.get_canonical_name_for_symbol(op)
  assert op_name is not None, op
  argnames = ', '.join('`%s`' % s.strip('[]') for s in elementwise_args)
  docstring = _ELEMENTWISE_DOCSTRING % dict(op_name=op_name, argnames=argnames)

  # Update name, docstring, signature, etc., for the wrapper, and return it.
  return tf_decorator.make_decorator(op, ragged_op, decorator_doc=docstring)
  def _maybe_find_duplicates(self):
    """Compute data structures containing information about duplicates.

    Find duplicates in `index` and decide on one to be the "master" name.

    Computes a reverse_index mapping each object id to its master name.

    Also computes a map `duplicate_of` from aliases to their master name (the
    master name itself has no entry in this map), and a map `duplicates` from
    master names to a lexicographically sorted list of all aliases for that name
    (incl. the master name).

    The master name is "tf." plus the object's canonical export name when
    that name is one of the aliases; otherwise it is the lexicographically
    first alias with the fewest dots.

    All these are computed and set as fields if they haven't already.
    """
    if self._reverse_index is not None:
      return

    # Maps the id of a symbol to its fully qualified name. For symbols that have
    # several aliases, this map contains the first one found.
    # We use id(py_object) to get a hashable value for py_object. Note all
    # objects in _index are in memory at the same time so this is safe.
    reverse_index = {}

    # Make a preliminary duplicates map. For all sets of duplicate names, it
    # maps the first name found to a list of all duplicate names.
    raw_duplicates = {}
    for full_name, py_object in six.iteritems(self._index):
      # We cannot use the duplicate mechanism for some constants, since e.g.,
      # id(c1) == id(c2) with c1=1, c2=1. This is unproblematic since constants
      # have no usable docstring and won't be documented automatically.
      # The empty tuple is such a shared constant as well. It is matched by
      # exact type and emptiness instead of `py_object is not ()`, because an
      # identity comparison with a literal is a SyntaxWarning on Python 3.8+.
      if (py_object is not None and
          not isinstance(py_object, six.integer_types + six.string_types +
                         (six.binary_type, six.text_type, float, complex, bool))
          and not (type(py_object) is tuple and not py_object)):
        object_id = id(py_object)
        if object_id in reverse_index:
          master_name = reverse_index[object_id]
          if master_name in raw_duplicates:
            raw_duplicates[master_name].append(full_name)
          else:
            raw_duplicates[master_name] = [master_name, full_name]
        else:
          reverse_index[object_id] = full_name
    # Decide on master names, rewire duplicates and make a duplicate_of map
    # mapping all non-master duplicates to the master name. The master symbol
    # does not have an entry in this map.
    duplicate_of = {}
    # Duplicates maps the main symbols to the set of all duplicates of that
    # symbol (incl. itself).
    duplicates = {}
    for names in raw_duplicates.values():
      # Every group is created with two names, so names[0] always exists.
      names = sorted(names)
      canonical = tf_export.get_canonical_name_for_symbol(self._index[names[0]])
      master_name = 'tf.%s' % canonical if canonical else None
      # Only trust the canonical name if it is one of the aliases; otherwise
      # `self._index[master_name]` below would raise a KeyError.
      if master_name not in names:
        # Choose the lexicographically first name with the minimum number of
        # submodules. This will prefer highest level namespace for any symbol.
        master_name = min(names, key=lambda name: name.count('.'))

      duplicates[master_name] = names
      for name in names:
        if name != master_name:
          duplicate_of[name] = master_name

      # Set the reverse index to the canonical name.
      reverse_index[id(self._index[master_name])] = master_name

    self._duplicate_of = duplicate_of
    self._duplicates = duplicates
    self._reverse_index = reverse_index
class DocGeneratorVisitor(object):
  """A visitor that generates docs for a python object when __call__ed."""

  # Member names that are never recorded in the index/tree. They are also
  # removed, in place, from the `children` list passed to `__call__`.
  _EXCLUDED_MEMBER_NAMES = frozenset(['__metaclass__'])

  def __init__(self, root_name=''):
    """Make a visitor.

    As this visitor is starting its traversal at a module or class, it will not
    be told the name of that object during traversal. `root_name` is the name it
    should use for that object, effectively prefixing all names with
    "root_name.".

    Args:
      root_name: The name of the root module/class.
    """
    self.set_root_name(root_name)
    self._index = {}
    self._tree = {}
    # The three duplicate-related maps are computed lazily by
    # `_maybe_find_duplicates`; `None` means "not computed yet".
    self._reverse_index = None
    self._duplicates = None
    self._duplicate_of = None

  def set_root_name(self, root_name):
    """Sets the root name for subsequent __call__s.

    Args:
      root_name: The name of the root module/class. A falsy value (e.g. `''`
        or `None`) means names are used without any prefix.
    """
    self._root_name = root_name or ''
    self._prefix = (root_name + '.') if root_name else ''

  @property
  def index(self):
    """A map from fully qualified names to objects to be documented.

    The index is filled when the visitor is passed to `traverse`.

    Returns:
      The index filled by traversal.
    """
    return self._index

  @property
  def tree(self):
    """A map from fully qualified names to all its child names for traversal.

    The full name to member names map is filled when the visitor is passed to
    `traverse`.

    Returns:
      The full name to member name map filled by traversal.
    """
    return self._tree

  @property
  def reverse_index(self):
    """A map from `id(object)` to the preferred fully qualified name.

    This map only contains non-primitive objects (no numbers or strings) present
    in `index` (for primitive objects, `id()` doesn't quite do the right thing).

    It is computed when it, `duplicate_of`, or `duplicates` are first accessed.

    Returns:
      The `id(object)` to full name map.
    """
    self._maybe_find_duplicates()
    return self._reverse_index

  @property
  def duplicate_of(self):
    """A map from duplicate full names to a preferred fully qualified name.

    This map only contains names that are not themselves a preferred name.

    It is computed when it, `reverse_index`, or `duplicates` are first accessed.

    Returns:
      The map from duplicate name to preferred name.
    """
    self._maybe_find_duplicates()
    return self._duplicate_of

  @property
  def duplicates(self):
    """A map from preferred full names to a list of all names for this symbol.

    This function returns a map from preferred (master) name for a symbol to a
    lexicographically sorted list of all aliases for that name (incl. the master
    name). Symbols without duplicate names do not appear in this map.

    It is computed when it, `reverse_index`, or `duplicate_of` are first
    accessed.

    Returns:
      The map from master name to list of all duplicate names.
    """
    self._maybe_find_duplicates()
    return self._duplicates

  def _add_prefix(self, name):
    """Adds the root name to a name.

    An empty `name` refers to the root object itself, so the bare root name is
    returned in that case.
    """
    return self._prefix + name if name else self._root_name

  def __call__(self, parent_name, parent, children):
    """Visitor interface, see `tensorflow/tools/common:traverse` for details.

    This method is called for each symbol found in a traversal using
    `tensorflow/tools/common:traverse`. It should not be called directly in
    user code.

    Args:
      parent_name: The fully qualified name of a symbol found during traversal.
      parent: The Python object referenced by `parent_name`.
      children: A list of `(name, py_object)` pairs enumerating, in alphabetical
        order, the children (as determined by `tf_inspect.getmembers`) of
          `parent`. `name` is the local name of `py_object` in `parent`.
          Entries whose name is in `_EXCLUDED_MEMBER_NAMES` are deleted from
          this list in place.

    Raises:
      RuntimeError: If this visitor is called with a `parent` that is not a
        class or module.
    """
    parent_name = self._add_prefix(parent_name)

    # Validate before touching any state so that a rejected `parent` does not
    # leave a half-registered entry behind in `_index` / `_tree`.
    if not (tf_inspect.ismodule(parent) or tf_inspect.isclass(parent)):
      raise RuntimeError('Unexpected type in visitor -- %s: %r' % (parent_name,
                                                                   parent))

    self._index[parent_name] = parent
    self._tree[parent_name] = []

    # We enumerate a snapshot but delete from the live list, so every deletion
    # shifts the remaining live positions left by one; track that offset.
    num_removed = 0
    for i, (name, child) in enumerate(list(children)):
      # Don't document __metaclass__. Removing it from `children` itself
      # presumably also keeps the traversal from descending into it.
      if name in self._EXCLUDED_MEMBER_NAMES:
        del children[i - num_removed]
        num_removed += 1
        continue

      full_name = '.'.join([parent_name, name]) if parent_name else name
      self._index[full_name] = child
      self._tree[parent_name].append(name)

  def _score_name(self, name):
    """Return a tuple of scores indicating how to sort for the best name.

    This function is meant to be used as the `key` to the `sorted` function.

    Sort order, most significant criterion first:
      1. Names that reach a member through the class that defines it (the short
         name is in that class's own `__dict__`) beat names through a subclass.
      2. Names not containing a "contrib" component beat "contrib" names.
      3. Names whose nearest enclosing module is a direct submodule of the root
         (e.g. `tf.submodule.thing`) beat the root namespace (`tf.thing`);
         otherwise a shallower enclosing module is preferred.
      4. Ties are broken lexicographically on the full name.

    Args:
      name: the full name to score, for example `tf.estimator.Estimator`

    Returns:
      A tuple of scores. When sorted the preferred name will have the lowest
      value.
    """
    parts = name.split('.')
    short_name = parts[-1]

    container = self._index['.'.join(parts[:-1])]

    defining_class_score = 1
    if tf_inspect.isclass(container) and short_name in container.__dict__:
      # Prefer the defining class.
      defining_class_score = -1

    contrib_score = 1 if 'contrib' in parts else -1

    # Strip trailing components until the remaining prefix names a module.
    while parts:
      container = self._index['.'.join(parts)]
      if tf_inspect.ismodule(container):
        break
      parts.pop()

    module_length = len(parts)
    if module_length == 2:
      # `tf.submodule.thing` is better than `tf.thing`
      module_length_score = -1
    else:
      # shorter is better
      module_length_score = module_length

    return (defining_class_score, contrib_score, module_length_score, name)

  def _maybe_find_duplicates(self):
    """Compute data structures containing information about duplicates.

    Find duplicates in `index` and decide on one to be the "master" name.

    Computes a reverse_index mapping each object id to its master name.

    Also computes a map `duplicate_of` from aliases to their master name (the
    master name itself has no entry in this map), and a map `duplicates` from
    master names to a lexicographically sorted list of all aliases for that name
    (incl. the master name).

    The master name is the `tf.`-prefixed canonical API name reported by
    `tf_export` when that name is one of the aliases found during traversal;
    otherwise it is the alias with the lowest `_score_name` key.

    All these are computed and set as fields if they haven't already.
    """
    if self._reverse_index is not None:
      return

    # We cannot use the duplicate mechanism for some constants, since e.g.,
    # id(c1) == id(c2) with c1=1, c2=1. This is unproblematic since constants
    # have no usable docstring and won't be documented automatically.
    primitive_types = six.integer_types + six.string_types + (
        six.binary_type, six.text_type, float, complex, bool)

    # Maps the id of a symbol to its fully qualified name. For symbols that have
    # several aliases, this map contains the first one found.
    # We use id(py_object) to get a hashable value for py_object. Note all
    # objects in _index are in memory at the same time so this is safe.
    reverse_index = {}

    # Make a preliminary duplicates map. For all sets of duplicate names, it
    # maps the first name found to a list of all duplicate names.
    raw_duplicates = {}
    for full_name, py_object in six.iteritems(self._index):
      # Use identity/type checks rather than `py_object in (None, ())`: a
      # membership test falls back to `==`, which can raise for objects with
      # elementwise equality (e.g. numpy arrays).
      if py_object is None or isinstance(py_object, primitive_types):
        continue
      if isinstance(py_object, tuple) and not py_object:
        continue

      object_id = id(py_object)
      if object_id in reverse_index:
        master_name = reverse_index[object_id]
        if master_name in raw_duplicates:
          raw_duplicates[master_name].append(full_name)
        else:
          raw_duplicates[master_name] = [master_name, full_name]
      else:
        reverse_index[object_id] = full_name

    # Decide on master names, rewire duplicates and make a duplicate_of map
    # mapping all non-master duplicates to the master name. The master symbol
    # does not have an entry in this map.
    duplicate_of = {}
    # Duplicates maps the main symbols to the set of all duplicates of that
    # symbol (incl. itself).
    duplicates = {}
    for names in raw_duplicates.values():
      # Every entry holds at least two names, so `names[0]` always exists.
      names = sorted(names)
      canonical_name = tf_export.get_canonical_name_for_symbol(
          self._index[names[0]])
      master_name = 'tf.%s' % canonical_name if canonical_name else None
      if master_name not in names:
        # The canonical name is missing or was never traversed (using it would
        # make `self._index[master_name]` below raise KeyError). Choose the
        # master name with a lexical sort on the tuples returned by
        # _score_name instead.
        master_name = min(names, key=self._score_name)

      duplicates[master_name] = names
      for name in names:
        if name != master_name:
          duplicate_of[name] = master_name

      # Set the reverse index to the canonical name.
      reverse_index[id(self._index[master_name])] = master_name

    self._duplicate_of = duplicate_of
    self._duplicates = duplicates
    self._reverse_index = reverse_index