def __init__(self, cls_ref, method_name, **kwargs):
    """Stores the class reference and method name, then initializes the layer.

    Args:
        cls_ref: Class object whose canonical exported name is looked up.
        method_name: Name of the method; used to build the default layer name.
        **kwargs: Forwarded to the base-class constructor. A `name` is
            generated when the caller did not supply one, and `autocast` is
            always overridden to False.
    """
    self.cls_ref = cls_ref
    self.method_name = method_name

    # Try the default API first; fall back to the `keras` API when the first
    # lookup yields nothing.
    symbol = get_canonical_name_for_symbol(
        self.cls_ref, add_prefix_to_v1_names=True)
    if not symbol:
        symbol = get_canonical_name_for_symbol(
            self.cls_ref, api_name='keras', add_prefix_to_v1_names=True)
    self.cls_symbol = symbol

    if 'name' not in kwargs:
        base_name = 'tf.' + self.cls_symbol + '.' + self.method_name
        kwargs['name'] = backend.unique_object_name(
            base_name, zero_based=True, avoid_observed_names=True)
    kwargs['autocast'] = False

    # Do not individually trace op layers in the SavedModel.
    self._must_restore_from_config = True

    super(ClassMethod, self).__init__(**kwargs)

    # Preserve all argument data structures when saving/loading a config
    # (e.g., don't unnest lists that contain one element).
    self._preserve_input_structure_in_config = True

    self._expects_training_arg = False
    self._expects_mask_arg = False
def register_dispatchers():
    """Constructs & registers OpDispatchers for ragged ops.

    Every op in the unary, unary-list, binary and ragged-dispatch tables is
    first checked to be an exported symbol, then the matching dispatcher is
    registered for it.

    Returns:
        A markdown snippet listing the canonical `tf.*` name of every op that
        received a dispatcher.

    Raises:
        AssertionError: If any op in the tables is not an exported symbol.
    """
    op_list = (_UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
               _BINARY_ELEMENTWISE_OPS +
               [x[0] for x in _RAGGED_DISPATCH_OPS])

    # Validate everything up front so no dispatcher is registered if any op
    # is not exported.
    names_attr = tf_export.API_ATTRS['tensorflow'].names
    for op in op_list:
        _, undecorated_op = tf_decorator.unwrap(op)
        if not hasattr(undecorated_op, names_attr):
            # The message was previously never %-formatted, so the offending
            # op was not reported. The 1-tuple keeps formatting safe even if
            # the op itself is a tuple.
            raise AssertionError(
                'Expected %s to be an exported symbol '
                '(while adding a RaggedTensor dispatcher)' % (undecorated_op,))

    for op in _UNARY_ELEMENTWISE_OPS:
        UnaryRaggedElementwiseDispatcher(op).register(op)

    for op in _UNARY_LIST_ELEMENTWISE_OPS:
        UnaryRaggedElementwiseDispatcher(op, True).register(op)

    for op in _BINARY_ELEMENTWISE_OPS:
        BinaryRaggedElementwiseDispatcher(op).register(op)

    for (original_op, ragged_op, args) in _RAGGED_DISPATCH_OPS:
        RaggedDispatcher(original_op, ragged_op, args).register(original_op)

    docstring = ('\n\n### Additional ops that support `RaggedTensor`\n\n' +
                 '\n'.join([
                     '* `tf.%s`' % tf_export.get_canonical_name_for_symbol(op)
                     for op in op_list
                 ]))
    return docstring
def serialize_keras_object(obj):
    """Serializes a Keras object into a plain python dictionary.

    This is the reciprocal of `deserialize_keras_object()`; see that function
    for a description of the config format.

    Args:
        obj: the Keras object to serialize.

    Returns:
        A python dict with the keys `"module"`, `"class_name"`, `"config"` and
        `"registered_name"`. It can be turned back into an object via
        `deserialize_keras_object()`.
    """
    obj_class = obj.__class__

    # This gets the `keras.*` exported name, such as "keras.optimizers.Adam".
    class_name = tf_export.get_canonical_name_for_symbol(
        obj_class, api_name="keras")

    if class_name is None:
        # No exported name: fall back to the python module and class name.
        # For a function this gives "builtins" / "function", and for a string
        # "builtins" / "str".
        module = obj_class.__module__
        class_name = obj_class.__name__
    else:
        module = None

    return {
        "module": module,
        "class_name": class_name,
        "config": _get_object_config(obj),
        "registered_name": _get_object_registered_name(obj),
    }
def _ragged_op_signature(op, ragged_args, ragged_varargs=False):
    """Returns a signature for the given op, marking ragged args in bold.

    Args:
        op: The op whose signature is rendered.
        ragged_args: Positions (indices into the op's argument list) of the
            arguments to wrap in `**`.
        ragged_varargs: If true, the op's `*args` parameter is rendered in
            bold as well.

    Returns:
        A markdown bullet of the form "* `tf.<name>`(<args>)".
    """
    op_name = tf_export.get_canonical_name_for_symbol(op)
    argspec = tf_inspect.getfullargspec(op)
    rendered = argspec.args

    # Mark ragged arguments in bold.
    for position in ragged_args:
        rendered[position] = '**' + rendered[position] + '**'

    # Add argument defaults, walking both sequences from the right since
    # defaults align with the trailing arguments.
    defaults = argspec.defaults
    if defaults is not None:
        for offset, default in enumerate(reversed(defaults), start=1):
            rendered[-offset] += '=`{!r}`'.format(default)

    # Add varargs and keyword args.
    varargs = argspec.varargs
    if varargs:
        if ragged_varargs:
            rendered.append('***' + varargs + '**')
        else:
            rendered.append('*' + varargs)
    if argspec.varkw:
        rendered.append('**' + argspec.varkw)

    return '* `tf.{}`({})'.format(op_name, ', '.join(rendered))
def _add_elementwise_ops_to_this_module(specs, verbose=False):
    """Adds ragged versions of the given ops to this module.

    Each generated op is stored in the module globals under its chosen name,
    and that name is appended to `_symbols_to_export`.

    Args:
        specs: A list of tuples containing the arguments for
            `make_elementwise_op`.
        verbose: If true, then display each op that gets added.
    """
    for spec in specs:
        original_op = spec[0]
        ragged_op = make_elementwise_op(*spec)
        canonical_name = tf_export.get_canonical_name_for_symbol(original_op)

        # A canonical name without a dot is used as-is; otherwise the op's
        # own `__name__` is used.
        has_submodule = '.' in canonical_name
        op_name = original_op.__name__ if has_submodule else canonical_name

        # Temporary hack (will be removed once dispatch is added for
        # RaggedTensors):
        if op_name == 'neg':
            op_name = 'negative'

        if verbose:
            message = (
                'Adding ragged_elementwise_op: tf.ragged.%s (based on tf.%s)' %
                (op_name, canonical_name))
            print(message)

        globals()[op_name] = ragged_op
        _symbols_to_export.append(op_name)
def handle(self, op, args, kwargs):
    """Wraps the call in a `TensorTracer` when any argument is a tracer arg.

    Args:
        op: The op being dispatched.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.

    Returns:
        `self.NOT_SUPPORTED` when no positional or keyword argument satisfies
        `self.is_tensor_tracer_arg`; otherwise a `TensorTracer` built from the
        op's canonical name and the original arguments.
    """
    # Dispatcher only applies if at least one arg is a TensorTracer.
    applies = any(self.is_tensor_tracer_arg(value) for value in args)
    if not applies:
        applies = any(
            self.is_tensor_tracer_arg(value) for value in kwargs.values())
    if not applies:
        return self.NOT_SUPPORTED

    return TensorTracer(get_canonical_name_for_symbol(op), args, kwargs)
def __init__(self, function, **kwargs):
    """Stores `function`, derives a layer name, and initializes the layer.

    Args:
        function: The callable this layer invokes.
        **kwargs: Forwarded to the base-class constructor. A `name` is
            generated when the caller did not supply one, and `autocast` is
            always overridden to False.
    """
    self.function = function

    # Try the default API first; fall back to the `keras` API when the first
    # lookup yields nothing.
    symbol = get_canonical_name_for_symbol(
        self.function, add_prefix_to_v1_names=True)
    if not symbol:
        symbol = get_canonical_name_for_symbol(
            self.function, api_name='keras', add_prefix_to_v1_names=True)
    self.symbol = symbol

    if 'name' not in kwargs:
        # Generate a name.
        # TFOpLambda layers avoid already-observed names, because users
        # cannot easily control the generated names. Without this avoidance,
        # users would be more likely to run into unavoidable duplicate layer
        # name collisions. (For standard layers users could just set `name`
        # when creating the layer to work around a collision, but they can't
        # do that for auto-generated layers.)
        base_name = (
            ('tf.' + self.symbol) if self.symbol else self.function.__name__)
        kwargs['name'] = backend.unique_object_name(
            base_name, zero_based=True, avoid_observed_names=True)
    kwargs['autocast'] = False

    # Decorate the function to produce this layer's call method.
    def _call_wrapper(*args, **kwargs):
        return self._call_wrapper(*args, **kwargs)

    self.call = tf.__internal__.decorator.make_decorator(
        function, _call_wrapper)

    # Do not individually trace op layers in the SavedModel.
    self._must_restore_from_config = True

    super(TFOpLambda, self).__init__(**kwargs)

    # Preserve all argument data structures when saving/loading a config
    # (e.g., don't unnest lists that contain one element).
    self._preserve_input_structure_in_config = True

    # Warning on every invocation will be quite irksome in Eager mode.
    self._already_warned = False

    self._expects_training_arg = False
    self._expects_mask_arg = False
def _score_name(self, name):
    """Scores `name`, ranking the canonical exported name first.

    Args:
        name: A fully qualified name present in `self._index`.

    Returns:
        The base class's score tuple prefixed with -1 when `name` equals
        "tf." plus the symbol's canonical name, and with 1 otherwise.
    """
    canonical = tf_export.get_canonical_name_for_symbol(self._index[name])
    is_canonical = canonical is not None and name == "tf." + canonical
    canonical_score = -1 if is_canonical else 1

    base_scores = super(TfExportAwareDocGeneratorVisitor,
                        self)._score_name(name)
    return (canonical_score,) + base_scores
def _score_name(self, name):
    """Returns the sort key for `name`, preferring the canonical export.

    Args:
        name: A fully qualified name present in `self._index`.

    Returns:
        A tuple whose first element is -1 if `name` is "tf." followed by the
        symbol's canonical name and 1 otherwise, followed by the elements of
        the base class's score.
    """
    symbol = self._index[name]
    canonical = tf_export.get_canonical_name_for_symbol(symbol)

    if canonical is not None and name == "tf." + canonical:
        prefix = (-1,)
    else:
        prefix = (1,)

    return prefix + super(TfExportAwareDocGeneratorVisitor,
                          self)._score_name(name)
def ragged_op_list():
    """Returns a string listing operators that have dispatchers registered.

    The result is a markdown section: a heading followed by one bullet per op
    giving its canonical `tf.*` name.
    """
    op_list = (_UNARY_ELEMENTWISE_OPS + _UNARY_LIST_ELEMENTWISE_OPS +
               _BINARY_ELEMENTWISE_OPS +
               [entry[0] for entry in _RAGGED_DISPATCH_OPS])

    heading = '\n\n### Additional ops that support `RaggedTensor`\n\n'
    bullets = []
    for op in op_list:
        bullets.append(
            '* `tf.%s`' % tf_export.get_canonical_name_for_symbol(op))
    return heading + '\n'.join(bullets)
def _update_docstring_with_api_list(target, api_list):
    """Replaces `<<API_LIST>>` in target.__doc__ with the given list of APIs.

    Args:
        target: Object whose `__doc__` contains the placeholder.
        api_list: Iterable of functions. Those without a canonical exported
            name are skipped.
    """
    entries = []
    for func in api_list:
        name = tf_export_lib.get_canonical_name_for_symbol(
            func, add_prefix_to_v1_names=True)
        if name is None:
            continue
        signature = tf_inspect.signature(func)
        entries.append(f" * `tf.{name}{signature}`")

    api_text = "\n".join(sorted(entries))
    target.__doc__ = target.__doc__.replace(" <<API_LIST>>", api_text)
def testExportMultipleFunctions(self):
    """Exports two functions under two names each and checks the lookups."""
    export_decorator1 = tf_export.tf_export('nameA', 'nameB')
    export_decorator2 = tf_export.tf_export('nameC', 'nameD')
    decorated_function1 = export_decorator1(_test_function)
    decorated_function2 = export_decorator2(_test_function2)

    # (decorated function, undecorated function, expected API names)
    cases = (
        (decorated_function1, _test_function, ('nameA', 'nameB')),
        (decorated_function2, _test_function2, ('nameC', 'nameD')),
    )

    # Decorating returns the function itself.
    for decorated, undecorated, _ in cases:
        self.assertEqual(decorated, undecorated)

    # The export names are recorded on the function.
    for decorated, _, expected_names in cases:
        self.assertEqual(expected_names, decorated._tf_api_names)

    # A non-first name resolves back to the function.
    self.assertEqual(
        tf_export.get_symbol_from_name('nameB'), decorated_function1)
    self.assertEqual(
        tf_export.get_symbol_from_name('nameD'), decorated_function2)

    # The canonical name round-trips through the symbol lookup.
    for decorated, _, _ in cases:
        canonical_name = tf_export.get_canonical_name_for_symbol(decorated)
        self.assertEqual(
            tf_export.get_symbol_from_name(canonical_name), decorated)
def _score_name(self, name):
    """Scores `name`, ranking the canonical exported name first.

    The canonical name is looked up under the tensorflow API and then the
    estimator API; the first non-None result is used.

    Args:
        name: A fully qualified name present in `self._index`.

    Returns:
        The base class's score tuple prefixed with -1 when `name` equals
        "tf." plus the canonical name, and with 1 otherwise.
    """
    api_names = (tf_export.TENSORFLOW_API_NAME, tf_export.ESTIMATOR_API_NAME)
    symbol = self._index[name]

    canonical = None
    for api_name in api_names:
        canonical = tf_export.get_canonical_name_for_symbol(
            symbol, api_name=api_name)
        if canonical is not None:
            break

    is_canonical = canonical is not None and name == "tf." + canonical
    canonical_score = -1 if is_canonical else 1

    return (canonical_score,) + super()._score_name(name)
def _score_name(self, path: doc_generator_visitor.ApiPath) -> TfNameScore:
    """Scores `path`, ranking the canonical exported name first.

    The canonical name is looked up under the tensorflow, keras and estimator
    APIs in that order; the first non-None result is used.

    Args:
        path: The API path; its parts are joined with "." to form the name
            looked up in `self._index`.

    Returns:
        A `TfNameScore` whose first field is -1 when the joined name equals
        "tf." plus the canonical name and 1 otherwise, and whose second field
        is the base class's score for `path`.
    """
    name = ".".join(path)
    api_names = (
        tf_export.TENSORFLOW_API_NAME,
        tf_export.KERAS_API_NAME,
        tf_export.ESTIMATOR_API_NAME,
    )
    symbol = self._index[name]

    canonical = None
    for api_name in api_names:
        canonical = tf_export.get_canonical_name_for_symbol(
            symbol, api_name=api_name)
        if canonical is not None:
            break

    if canonical is not None and name == "tf." + canonical:
        canonical_score = -1
    else:
        canonical_score = 1

    return self.TfNameScore(canonical_score, super()._score_name(path))
def testExportSingleFunction(self):
    """Exports one function under two names and checks every lookup path."""
    export_decorator = tf_export.tf_export('nameA', 'nameB')
    decorated_function = export_decorator(_test_function)

    # `assertEquals` is a deprecated alias that was removed in Python 3.12;
    # use `assertEqual` throughout, as the rest of this method already does.
    self.assertEqual(decorated_function, _test_function)
    self.assertEqual(('nameA', 'nameB'), decorated_function._tf_api_names)
    self.assertEqual(['nameA', 'nameB'],
                     tf_export.get_v1_names(decorated_function))
    self.assertEqual(['nameA', 'nameB'],
                     tf_export.get_v2_names(decorated_function))

    # Both exported names resolve back to the function.
    self.assertEqual(
        tf_export.get_symbol_from_name('nameA'), decorated_function)
    self.assertEqual(
        tf_export.get_symbol_from_name('nameB'), decorated_function)

    # The canonical name round-trips through the symbol lookup.
    self.assertEqual(
        tf_export.get_symbol_from_name(
            tf_export.get_canonical_name_for_symbol(decorated_function)),
        decorated_function)
def testExportSingleFunctionV1Only(self):
    """Exports one function with v1-only names and checks the lookups."""
    decorate = tf_export.tf_export(v1=['nameA', 'nameB'])
    decorated_function = decorate(_test_function)

    # Decorating returns the function itself.
    self.assertEqual(decorated_function, _test_function)

    # The names are recorded as v1 names only.
    self.assertAllEqual(
        ('nameA', 'nameB'), decorated_function._tf_api_names_v1)
    self.assertAllEqual(
        ['nameA', 'nameB'], tf_export.get_v1_names(decorated_function))
    self.assertEqual([], tf_export.get_v2_names(decorated_function))

    # Each name resolves under the `compat.v1.` prefix.
    for prefixed_name in ('compat.v1.nameA', 'compat.v1.nameB'):
        self.assertEqual(
            tf_export.get_symbol_from_name(prefixed_name), decorated_function)

    # The prefixed canonical name round-trips through the symbol lookup.
    canonical_name = tf_export.get_canonical_name_for_symbol(
        decorated_function, add_prefix_to_v1_names=True)
    self.assertEqual(
        tf_export.get_symbol_from_name(canonical_name), decorated_function)
def _add_elementwise_ops_to_this_module(specs, verbose=False):
    """Adds ragged versions of the given ops to this module.

    Each generated op is stored in the module globals under its chosen name,
    and that name is appended to `_symbols_to_export`.

    Args:
        specs: A list of tuples containing the arguments for
            `make_elementwise_op`.
        verbose: If true, then display each op that gets added.
    """
    for spec in specs:
        original_op = spec[0]
        ragged_op = make_elementwise_op(*spec)
        canonical_name = tf_export.get_canonical_name_for_symbol(original_op)

        # A canonical name without a dot is used as-is; otherwise the op's
        # own `__name__` is used.
        if '.' in canonical_name:
            op_name = original_op.__name__
        else:
            op_name = canonical_name

        if verbose:
            message = (
                'Adding ragged_elementwise_op: tf.ragged.%s (based on tf.%s)' %
                (op_name, canonical_name))
            print(message)

        globals()[op_name] = ragged_op
        _symbols_to_export.append(op_name)
def _ragged_op_signature(op, ragged_args):
    """Returns a signature for the given op, marking ragged args in bold.

    Args:
        op: The op whose signature is rendered.
        ragged_args: Positions (indices into the op's argument list) of the
            arguments to wrap in `**`.

    Returns:
        A markdown bullet of the form "* `tf.<name>`(<args>)".
    """
    op_name = tf_export.get_canonical_name_for_symbol(op)
    argspec = tf_inspect.getfullargspec(op)
    arg_names = argspec.args

    # Mark ragged arguments in bold.
    for pos in ragged_args:
        arg_names[pos] = '**' + arg_names[pos] + '**'

    # Add argument defaults. `defaults` is None (not an empty tuple) when the
    # op declares no default values, so guard before taking its length.
    if argspec.defaults is not None:
        for pos in range(-1, -len(argspec.defaults) - 1, -1):
            arg_names[pos] += '=`{!r}`'.format(argspec.defaults[pos])

    # Add varargs and keyword args.
    if argspec.varargs:
        arg_names.append('*' + argspec.varargs)
    if argspec.varkw:
        arg_names.append('**' + argspec.varkw)

    return '* `tf.{}`({})'.format(op_name, ', '.join(arg_names))
def add_imports_for_symbol(module_code_builder,
                           symbol,
                           source_module_name,
                           source_name,
                           api_name,
                           api_version,
                           output_module_prefix='',
                           decorators=None):
    """Add imports for the given symbol to `module_code_builder`.

    Args:
        module_code_builder: `_ModuleInitCodeBuilder` instance.
        symbol: A symbol.
        source_module_name: Module that we can import the symbol from.
        source_name: Name we can import the symbol with.
        api_name: API name. Currently, must be either `tensorflow` or
            `estimator`.
        api_version: API version.
        output_module_prefix: Prefix to prepend to destination module.
        decorators: Tuple of symbol's decorators.
    """
    v2_attrs = API_ATTRS[api_name]
    names_attr_v2 = v2_attrs.names
    constants_attr_v2 = v2_attrs.constants

    if api_version == 1:
        v1_attrs = API_ATTRS_V1[api_name]
        names_attr = v1_attrs.names
        constants_attr = v1_attrs.constants
    else:
        names_attr = names_attr_v2
        constants_attr = constants_attr_v2

    def _destination(export):
        """Splits `export` into (prefixed destination module, name)."""
        module, name = _get_name_and_module(export)
        return _join_modules(output_module_prefix, module), name

    # If symbol is _tf_api_constants attribute, then add the constants.
    if source_name == constants_attr:
        for exports, name in symbol:
            for export in exports:
                dest_module, dest_name = _destination(export)
                module_code_builder.add_import(
                    -1, dest_module, source_module_name, name, dest_name)

    # Only symbols that carry the names attribute directly in their own
    # `__dict__` get imports generated.
    if not (hasattr(symbol, '__dict__') and names_attr in symbol.__dict__):
        return

    # Get a list of all V2 names if we generate V1 API to check for
    # deprecations.
    exports_v2 = []
    if api_version == 1 and hasattr(symbol, names_attr_v2):
        exports_v2 = getattr(symbol, names_attr_v2)

    # Resolved lazily, the first time a deprecated endpoint is found.
    canonical_endpoint = None

    # Generate import statements for symbols.
    for export in getattr(symbol, names_attr):  # pylint: disable=protected-access
        dest_module, dest_name = _destination(export)
        module_code_builder.add_import(
            id(symbol), dest_module, source_module_name, source_name,
            dest_name)

        # Export is deprecated if it is not in 2.0.
        is_deprecated = (
            export not in exports_v2 and
            not dest_module.startswith(_COMPAT_MODULE_PREFIX) and
            not has_deprecation_decorator(symbol, decorators))
        if not is_deprecated:
            continue

        if not canonical_endpoint:
            canonical_endpoint = tf_export.get_canonical_name_for_symbol(
                symbol, api_name, True)
        module_code_builder.add_deprecated_endpoint(
            dest_module, dest_name, canonical_endpoint)
def _maybe_find_duplicates(self):
    """Compute data structures containing information about duplicates.

    Find duplicates in `index` and decide on one to be the "master" name.

    Three fields are computed and stored, unless they already have been:

    * `_reverse_index`: maps each object id to its master name.
    * `_duplicate_of`: maps aliases to their master name (the master name
      itself has no entry in this map).
    * `_duplicates`: maps master names to a lexicographically sorted list of
      all aliases for that name (incl. the master name).
    """
    if self._reverse_index is not None:
        return

    # Maps the id of a symbol to its fully qualified name. For symbols that
    # have several aliases, this map contains the first one found.
    # We use id(py_object) to get a hashable value for py_object. Note all
    # objects in _index are in memory at the same time so this is safe.
    reverse_index = {}

    # Preliminary duplicates map. For all sets of duplicate names, it maps
    # the first name found to a list of all duplicate names.
    raw_duplicates = {}

    # We cannot use the duplicate mechanism for some constants, since e.g.,
    # id(c1) == id(c2) with c1=1, c2=1. This is unproblematic since constants
    # have no usable docstring and won't be documented automatically.
    constant_types = (
        six.integer_types + six.string_types +
        (six.binary_type, six.text_type, float, complex, bool))

    for full_name, py_object in six.iteritems(self._index):
        if py_object is None:
            continue
        if isinstance(py_object, constant_types):
            continue
        if py_object is ():  # pylint: disable=literal-comparison
            continue

        object_id = id(py_object)
        if object_id not in reverse_index:
            reverse_index[object_id] = full_name
            continue

        first_name = reverse_index[object_id]
        raw_duplicates.setdefault(first_name, [first_name]).append(full_name)

    # Decide on master names, rewire duplicates and make a duplicate_of map
    # mapping all non-master duplicates to the master name. The master symbol
    # does not have an entry in this map.
    duplicate_of = {}

    # Duplicates maps the main symbols to the set of all duplicates of that
    # symbol (incl. itself).
    duplicates = {}

    for names in raw_duplicates.values():
        aliases = sorted(names)

        canonical = None
        if aliases:
            canonical = tf_export.get_canonical_name_for_symbol(
                self._index[aliases[0]])

        if canonical:
            master_name = 'tf.%s' % canonical
        else:
            # Choose the master name with a lexical sort on the tuples
            # returned by _score_name.
            master_name = min(aliases, key=self._score_name)

        duplicates[master_name] = aliases
        duplicate_of.update(
            (alias, master_name) for alias in aliases if alias != master_name)

        # Set the reverse index to the canonical name.
        reverse_index[id(self._index[master_name])] = master_name

    self._duplicate_of = duplicate_of
    self._duplicates = duplicates
    self._reverse_index = reverse_index
def make_elementwise_op(op, *elementwise_args):
    """Returns a ragged-tensor version of the elementwise operation `op`.

    The returned operation will:

    1. Broadcast the elementwise arguments to have a compatible shape.
       An exception is raised if the tensors are not broadcast-compatible.
    2. Call `op`, substituting the dense values of the broadcasted tensor for
       each elementwise argument.
    3. Return a potentially ragged tensor constructed from the output of `op`
       and the broadcasted tensors' nested row splits.

    For example, you can construct a ragged-tensor version of the standard
    operation `tf.add` by calling `make_elementwise_op(tf.add, 'x', 'y')`.

    Args:
        op: The operation to wrap.
        *elementwise_args: The names of arguments to `op` that are treated as
            elementwise. Arguments that take a list of tensors should have
            their names wrapped in square brackets (e.g. "[inputs]").

    Returns:
        The wrapper produced by `tf_decorator.make_decorator` around the
        ragged version of `op`.

    Raises:
        ValueError: If any name specified in `elementwise_args` is not the
            name of an argument to `op`.
    """
    elementwise_arg_infos = _get_arg_infos(op, elementwise_args)

    def ragged_op(*args, **kwargs):
        """Ragged version of `op`."""
        args = list(args)

        # Collect all of the elementwise arguments, and put them in a single
        # dict whose values are the (potentially ragged) tensors that need to
        # be broadcast to a common shape. The keys of this dict are tuples
        # (argkey, index), where argkey is an int for positional args or a
        # string for keyword args; and index is None for non-list args and
        # the index of the tensor for list args.
        tensors = {}
        for (name, position, is_list) in elementwise_arg_infos.values():
            if position < len(args):
                argkey, container = position, args
            elif name in kwargs:
                argkey, container = name, kwargs
            else:
                # The caller did not pass this argument at all.
                continue

            if is_list:
                # Copy so the per-element replacement below never mutates the
                # caller's sequence.
                container[argkey] = list(container[argkey])
                for (index, arg) in enumerate(container[argkey]):
                    tensors[argkey, index] = arg
            else:
                tensors[argkey, None] = container[argkey]

        with ops.name_scope(None, op.__name__, tensors.values()):
            # Convert all inputs to tensors or ragged tensors. Only existing
            # keys are reassigned, so iterating while assigning is safe.
            for ((key, index), tensor) in tensors.items():
                argname = elementwise_arg_infos[key].name
                tensors[key, index] = (
                    ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
                        tensor, name=argname))

            # Broadcast tensors to have compatible shapes.
            broadcast_args, result_splits, broadcast_check_ops = (
                _broadcast_elementwise_args(tensors))

            # Replace tensor arguments with their dense values.
            for ((key, index), tensor) in broadcast_args.items():
                if not ragged_tensor.is_ragged(tensor):
                    continue
                if isinstance(key, int):
                    if index is None:
                        args[key] = tensor.inner_values
                    else:
                        args[key][index] = tensor.inner_values
                elif isinstance(key, str) and index is None:
                    kwargs[key] = tensor.inner_values
                else:
                    assert isinstance(key, str) and index is not None
                    kwargs[key][index] = tensor.inner_values

            # Call the elementwise op on the broadcasted dense values.
            with ops.control_dependencies(broadcast_check_ops):
                result_values = op(*args, **kwargs)

            # Restore any ragged dimensions that we stripped off, and return
            # the result.
            return ragged_factory_ops.from_nested_row_splits(
                result_values, result_splits)

    # Construct the docstring.
    op_name = tf_export.get_canonical_name_for_symbol(op)
    assert op_name is not None, op
    argnames = ', '.join('`%s`' % s.strip('[]') for s in elementwise_args)
    docstring = _ELEMENTWISE_DOCSTRING % dict(
        op_name=op_name, argnames=argnames)

    # Update name, docstring, signature, etc., for the wrapper, and return it.
    return tf_decorator.make_decorator(op, ragged_op, decorator_doc=docstring)
def make_elementwise_op(op, *elementwise_args):
    """Returns a ragged-tensor version of the elementwise operation `op`.

    The returned operation will:

    1. Broadcast the elementwise arguments to have a compatible shape.
       An exception is raised if the tensors are not broadcast-compatible.
    2. Call `op`, substituting the dense values of the broadcasted tensor for
       each elementwise argument.
    3. Return a potentially ragged tensor constructed from the output of `op`
       and the broadcasted tensors' nested row splits.

    For example, you can construct a ragged-tensor version of the standard
    operation `tf.add` by calling `make_elementwise_op(tf.add, 'x', 'y')`.

    Args:
        op: The operation to wrap.
        *elementwise_args: The names of arguments to `op` that are treated as
            elementwise. Arguments that take a list of tensors should have
            their names wrapped in square brackets (e.g. "[inputs]").

    Returns:
        The wrapper produced by `tf_decorator.make_decorator` around the
        ragged version of `op`.

    Raises:
        ValueError: If any name specified in `elementwise_args` is not the
            name of an argument to `op`.
    """
    elementwise_arg_infos = _get_arg_infos(op, elementwise_args)

    def ragged_op(*args, **kwargs):
        """Ragged version of `op`."""
        args = list(args)

        # Gather every elementwise argument into one dict keyed by
        # (argkey, index). `argkey` is an int for positional args or a string
        # for keyword args; `index` is None for non-list args and the
        # element's position for list args. The values are the (potentially
        # ragged) tensors that need to be broadcast to a common shape.
        collected = {}
        for (name, position, is_list) in elementwise_arg_infos.values():
            passed_positionally = position < len(args)
            if passed_positionally and is_list:
                args[position] = list(args[position])
                for (i, element) in enumerate(args[position]):
                    collected[position, i] = element
            elif passed_positionally:
                collected[position, None] = args[position]
            elif name in kwargs and is_list:
                kwargs[name] = list(kwargs[name])
                for (i, element) in enumerate(kwargs[name]):
                    collected[name, i] = element
            elif name in kwargs:
                collected[name, None] = kwargs[name]

        with ops.name_scope(None, op.__name__, collected.values()):
            # Convert all inputs to tensors or ragged tensors. Only existing
            # keys are reassigned, so iterating while assigning is safe.
            for ((key, index), value) in collected.items():
                info = elementwise_arg_infos[key]
                collected[key, index] = (
                    ragged_factory_ops.convert_to_tensor_or_ragged_tensor(
                        value, name=info.name))

            # Broadcast tensors to have compatible shapes.
            broadcast_result = _broadcast_elementwise_args(collected)
            broadcast_args, result_splits, broadcast_check_ops = (
                broadcast_result)

            # Replace tensor arguments with their dense values.
            for ((key, index), value) in broadcast_args.items():
                if not ragged_tensor.is_ragged(value):
                    continue
                if isinstance(key, int):
                    if index is None:
                        args[key] = value.inner_values
                    else:
                        args[key][index] = value.inner_values
                elif isinstance(key, str) and index is None:
                    kwargs[key] = value.inner_values
                else:
                    assert isinstance(key, str) and index is not None
                    kwargs[key][index] = value.inner_values

            # Call the elementwise op on the broadcasted dense values.
            with ops.control_dependencies(broadcast_check_ops):
                result_values = op(*args, **kwargs)

            # Restore any ragged dimensions that we stripped off, and return
            # the result.
            return ragged_factory_ops.from_nested_row_splits(
                result_values, result_splits)

    # Construct the docstring.
    op_name = tf_export.get_canonical_name_for_symbol(op)
    assert op_name is not None, op
    quoted_names = ['`%s`' % s.strip('[]') for s in elementwise_args]
    argnames = ', '.join(quoted_names)
    docstring = _ELEMENTWISE_DOCSTRING % dict(
        op_name=op_name, argnames=argnames)

    # Update name, docstring, signature, etc., for the wrapper, and return it.
    return tf_decorator.make_decorator(op, ragged_op, decorator_doc=docstring)
def _maybe_find_duplicates(self):
    """Compute data structures containing information about duplicates.

    Find duplicates in `index` and decide on one to be the "master" name.

    Three fields are computed and stored, unless they already have been:

    * `_reverse_index`: maps each object id to its master name.
    * `_duplicate_of`: maps aliases to their master name (the master name
      itself has no entry in this map).
    * `_duplicates`: maps master names to a lexicographically sorted list of
      all aliases for that name (incl. the master name).
    """
    if self._reverse_index is not None:
        return

    # Maps the id of a symbol to its fully qualified name. For symbols that
    # have several aliases, this map contains the first one found.
    # We use id(py_object) to get a hashable value for py_object. Note all
    # objects in _index are in memory at the same time so this is safe.
    reverse_index = {}

    # Preliminary duplicates map. For all sets of duplicate names, it maps
    # the first name found to a list of all duplicate names.
    raw_duplicates = {}

    # We cannot use the duplicate mechanism for some constants, since e.g.,
    # id(c1) == id(c2) with c1=1, c2=1. This is unproblematic since constants
    # have no usable docstring and won't be documented automatically.
    constant_types = (
        six.integer_types + six.string_types +
        (six.binary_type, six.text_type, float, complex, bool))

    for full_name, py_object in six.iteritems(self._index):
        skip = (
            py_object is None or
            isinstance(py_object, constant_types) or
            py_object is ())
        if skip:
            continue

        object_id = id(py_object)
        if object_id in reverse_index:
            first_name = reverse_index[object_id]
            raw_duplicates.setdefault(
                first_name, [first_name]).append(full_name)
        else:
            reverse_index[object_id] = full_name

    # Decide on master names, rewire duplicates and make a duplicate_of map
    # mapping all non-master duplicates to the master name. The master symbol
    # does not have an entry in this map.
    duplicate_of = {}

    # Duplicates maps the main symbols to the set of all duplicates of that
    # symbol (incl. itself).
    duplicates = {}

    for names in raw_duplicates.values():
        aliases = sorted(names)

        canonical = None
        if aliases:
            canonical = tf_export.get_canonical_name_for_symbol(
                self._index[aliases[0]])

        if canonical:
            master_name = 'tf.%s' % canonical
        else:
            # Choose the lexicographically first name with the minimum number
            # of submodules. This will prefer highest level namespace for any
            # symbol.
            master_name = min(aliases, key=lambda alias: alias.count('.'))

        duplicates[master_name] = aliases
        duplicate_of.update(
            (alias, master_name) for alias in aliases if alias != master_name)

        # Set the reverse index to the canonical name.
        reverse_index[id(self._index[master_name])] = master_name

    self._duplicate_of = duplicate_of
    self._duplicates = duplicates
    self._reverse_index = reverse_index
class DocGeneratorVisitor(object):
  """A visitor that generates docs for a python object when __call__ed."""

  def __init__(self, root_name=''):
    """Make a visitor.

    As this visitor is starting its traversal at a module or class, it will not
    be told the name of that object during traversal. `root_name` is the name it
    should use for that object, effectively prefixing all names with
    "root_name.".

    Args:
      root_name: The name of the root module/class.
    """
    self.set_root_name(root_name)
    self._index = {}
    self._tree = {}
    # The three duplicate maps are computed lazily by `_maybe_find_duplicates`;
    # `_reverse_index is None` means "not computed yet".
    self._reverse_index = None
    self._duplicates = None
    self._duplicate_of = None

  def set_root_name(self, root_name):
    """Sets the root name for subsequent __call__s."""
    self._root_name = root_name or ''
    self._prefix = (root_name + '.') if root_name else ''

  @property
  def index(self):
    """A map from fully qualified names to objects to be documented.

    The index is filled when the visitor is passed to `traverse`.

    Returns:
      The index filled by traversal.
    """
    return self._index

  @property
  def tree(self):
    """A map from fully qualified names to all its child names for traversal.

    The full name to member names map is filled when the visitor is passed to
    `traverse`.

    Returns:
      The full name to member name map filled by traversal.
    """
    return self._tree

  @property
  def reverse_index(self):
    """A map from `id(object)` to the preferred fully qualified name.

    This map only contains non-primitive objects (no numbers or strings) present
    in `index` (for primitive objects, `id()` doesn't quite do the right thing).

    It is computed when it, `duplicate_of`, or `duplicates` are first accessed.

    Returns:
      The `id(object)` to full name map.
    """
    self._maybe_find_duplicates()
    return self._reverse_index

  @property
  def duplicate_of(self):
    """A map from duplicate full names to a preferred fully qualified name.

    This map only contains names that are not themselves a preferred name.

    It is computed when it, `reverse_index`, or `duplicates` are first accessed.

    Returns:
      The map from duplicate name to preferred name.
    """
    self._maybe_find_duplicates()
    return self._duplicate_of

  @property
  def duplicates(self):
    """A map from preferred full names to a list of all names for this symbol.

    This function returns a map from preferred (master) name for a symbol to a
    lexicographically sorted list of all aliases for that name (incl. the master
    name). Symbols without duplicate names do not appear in this map.

    It is computed when it, `reverse_index`, or `duplicate_of` are first
    accessed.

    Returns:
      The map from master name to list of all duplicate names.
    """
    self._maybe_find_duplicates()
    return self._duplicates

  def _add_prefix(self, name):
    """Adds the root name to a name."""
    return self._prefix + name if name else self._root_name

  def __call__(self, parent_name, parent, children):
    """Visitor interface, see `tensorflow/tools/common:traverse` for details.

    This method is called for each symbol found in a traversal using
    `tensorflow/tools/common:traverse`. It should not be called directly in
    user code.

    Args:
      parent_name: The fully qualified name of a symbol found during traversal.
      parent: The Python object referenced by `parent_name`.
      children: A list of `(name, py_object)` pairs enumerating, in alphabetical
        order, the children (as determined by `tf_inspect.getmembers`) of
        `parent`. `name` is the local name of `py_object` in `parent`. Any
        `__metaclass__` entry is removed from this list in place.

    Raises:
      RuntimeError: If this visitor is called with a `parent` that is not a
        class or module.
    """
    parent_name = self._add_prefix(parent_name)
    # The parent is recorded before it is validated, so it is present in the
    # index and tree even when the RuntimeError below is raised.
    self._index[parent_name] = parent
    self._tree[parent_name] = []

    if not (tf_inspect.ismodule(parent) or tf_inspect.isclass(parent)):
      raise RuntimeError('Unexpected type in visitor -- %s: %r' %
                         (parent_name, parent))

    # Don't document __metaclass__. The caller's list is filtered in place
    # instead of deleting by the index of a snapshot, which would remove the
    # wrong element after the first deletion. The list is only touched when
    # there is actually something to drop.
    documented = [(name, child) for name, child in children
                  if name != '__metaclass__']
    if len(documented) != len(children):
      children[:] = documented

    for name, child in documented:
      full_name = '.'.join([parent_name, name]) if parent_name else name
      self._index[full_name] = child
      self._tree[parent_name].append(name)

  def _score_name(self, name):
    """Return a tuple of scores indicating how to sort for the best name.

    This function is meant to be used as the `key` to the `sorted` function.

    This sorting in order:
      Prefers names referring to the defining class, over a subclass.
      Prefers names that are not in "contrib".
      Prefers submodules to the root namespace.
      Prefers short names `tf.thing` over `tf.a.b.c.thing`
      Sorts lexicographically on name parts.

    Args:
      name: the full name to score, for example `tf.estimator.Estimator`

    Returns:
      A tuple of scores. When sorted the preferred name will have the lowest
      value.
    """
    parts = name.split('.')
    short_name = parts[-1]

    container = self._index['.'.join(parts[:-1])]

    defining_class_score = 1
    if tf_inspect.isclass(container):
      if short_name in container.__dict__:
        # prefer the defining class
        defining_class_score = -1

    contrib_score = -1
    if 'contrib' in parts:
      contrib_score = 1

    # Strip trailing components until the remaining path names a module; what
    # is left in `parts` is the module path that contains the symbol.
    while parts:
      container = self._index['.'.join(parts)]
      if tf_inspect.ismodule(container):
        break
      parts.pop()

    module_length = len(parts)
    if len(parts) == 2:
      # `tf.submodule.thing` is better than `tf.thing`
      module_length_score = -1
    else:
      # shorter is better
      module_length_score = module_length

    return (defining_class_score, contrib_score, module_length_score, name)

  def _maybe_find_duplicates(self):
    """Compute data structures containing information about duplicates.

    Find duplicates in `index` and decide on one to be the "master" name.

    Computes a reverse_index mapping each object id to its master name.

    Also computes a map `duplicate_of` from aliases to their master name (the
    master name itself has no entry in this map), and a map `duplicates` from
    master names to a lexicographically sorted list of all aliases for that name
    (incl. the master name).

    All these are computed and set as fields if they haven't already.
    """
    if self._reverse_index is not None:
      return

    # We cannot use the duplicate mechanism for some constants, since e.g.,
    # id(c1) == id(c2) with c1=1, c2=1. This is unproblematic since constants
    # have no usable docstring and won't be documented automatically.
    primitive_types = (
        six.integer_types + six.string_types +
        (six.binary_type, six.text_type, float, complex, bool))
    # `None` and the empty tuple are excluded by identity, never by `==`, so
    # that an overloaded `__eq__` on an indexed object is not invoked. The
    # empty tuple is bound to a name because comparing against a literal with
    # `is` raises a SyntaxWarning on Python 3.8+.
    empty_tuple = ()

    # Maps the id of a symbol to its fully qualified name. For symbols that have
    # several aliases, this map contains the first one found.
    # We use id(py_object) to get a hashable value for py_object. Note all
    # objects in _index are in memory at the same time so this is safe.
    reverse_index = {}

    # Make a preliminary duplicates map. For all sets of duplicate names, it
    # maps the first name found to a list of all duplicate names.
    raw_duplicates = {}
    for full_name, py_object in six.iteritems(self._index):
      if (py_object is None or py_object is empty_tuple or
          isinstance(py_object, primitive_types)):
        continue

      object_id = id(py_object)
      if object_id in reverse_index:
        first_name = reverse_index[object_id]
        raw_duplicates.setdefault(first_name, [first_name]).append(full_name)
      else:
        reverse_index[object_id] = full_name

    # Decide on master names, rewire duplicates and make a duplicate_of map
    # mapping all non-master duplicates to the master name. The master symbol
    # does not have an entry in this map.
    duplicate_of = {}
    # Duplicates maps the main symbols to the set of all duplicates of that
    # symbol (incl. itself).
    duplicates = {}
    for names in raw_duplicates.values():
      names = sorted(names)
      master_name = (
          tf_export.get_canonical_name_for_symbol(self._index[names[0]])
          if names else None)
      if master_name:
        # NOTE(review): assumes the index was built with root name 'tf', so
        # that 'tf.<canonical name>' is itself a key of `_index` -- confirm.
        master_name = 'tf.%s' % master_name
      else:
        # Choose the master name with a lexical sort on the tuples returned by
        # _score_name.
        master_name = min(names, key=self._score_name)

      duplicates[master_name] = names
      for name in names:
        if name != master_name:
          duplicate_of[name] = master_name

      # Set the reverse index to the canonical name.
      reverse_index[id(self._index[master_name])] = master_name

    self._duplicate_of = duplicate_of
    self._duplicates = duplicates
    self._reverse_index = reverse_index