def _try_get_state_scope(name, mark_name_scope_used=True):
  """Returns a fresh variable/name scope for a module's state.

  In order to import a module into a given scope without major complications
  we require the scope to be empty. This function deals with picking an unused
  scope in which to define the module state. This is non-trivial in cases
  where name_scope and variable_scopes are out of sync, e.g. TPUs or
  re-entering scopes.

  Args:
    name: A string with the name of the module as supplied by the client.
    mark_name_scope_used: a boolean, indicating whether to mark the name
        scope of the returned value as used.

  Raises:
    RuntimeError: if the name scope of the freshly created variable scope is
        already used.
  """
  tmp_scope_name = tf_v1.get_variable_scope().name
  if tmp_scope_name:
    tmp_scope_name += "/"
  with tf.name_scope(tmp_scope_name):
    # Pick an unused variable scope.
    with tf_v1.variable_scope(
        None, default_name=name, auxiliary_name_scope=False) as vs:
      abs_state_scope = vs.name + "/"
    # Verify that the name scope is available and mark it used if requested.
    graph = tf_v1.get_default_graph()
    unique_name_scope = graph.unique_name(name, mark_name_scope_used) + "/"
    if unique_name_scope != abs_state_scope:
      raise RuntimeError(
          "variable_scope %s was unused but the corresponding "
          "name_scope was already taken." % abs_state_scope)
  return abs_state_scope
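A minimal sketch (not part of the library) of how the scope uniquification behaves; it assumes this file's `tf_v1` alias is in scope and a fresh graph:

```python
g = tf_v1.Graph()
with g.as_default():
  first = _try_get_state_scope("module")   # "module/"
  # A second request for the same name is uniquified, typically to
  # "module_1/", because the name scope "module" is now marked as used.
  second = _try_get_state_scope("module")
```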
def __init__(self, spec, trainable=False, name="module", tags=None):
  """Constructs a Module to be used in the current graph.

  This creates the module `state-graph` under an unused variable_scope based
  on `name`. During this call a Module will:

  - Add GLOBAL_VARIABLES under its scope. Those variables may be added to
    the TRAINABLE_VARIABLES collection (depending on the `trainable`
    parameter) and to the MODEL_VARIABLES. The variables must be initialized
    before use, and can be checkpointed as usual.

  - Add ops to the INIT_TABLE_OPS collection, which must be run during
    session initialization, and add constant tensors to ASSET_FILEPATHS that
    are needed during the execution of such ops.

  - Add tensors to the REGULARIZATION_LOSSES collection (depending on the
    `trainable` parameter).

  Args:
    spec: A ModuleSpec defining the Module to instantiate or a path where
      to load a ModuleSpec from via `load_module_spec`.
    trainable: whether the Module is trainable. If False, no variables are
      added to the TRAINABLE_VARIABLES collection, and no tensors are added
      to the REGULARIZATION_LOSSES collection.
    name: A string, the variable scope name under which to create the
      Module. It will be uniquified and the equivalent name scope must be
      unused.
    tags: A set of strings specifying the graph variant to use.

  Raises:
    RuntimeError: explaining the reason why it failed to instantiate the
      Module.
    ValueError: if the requested graph variant does not exist.
    tf.errors.NotFoundError: if the requested graph contains unknown ops.
  """
  self._graph = tf_v1.get_default_graph()
  self._spec = as_module_spec(spec)
  self._trainable = trainable

  self._tags = set(tags or [])
  if self._tags not in self._spec.get_tags():
    tags = sorted(list(tags)) if tags else tags
    raise ValueError("No such graph variant: tags=%r" % tags)

  abs_state_scope = _try_get_state_scope(name, mark_name_scope_used=False)
  self._name = abs_state_scope.split("/")[-2]

  abs_parent_scope = abs_state_scope.split("/")[:-2]
  if abs_parent_scope:
    abs_parent_scope = "/".join(abs_parent_scope) + "/"
  else:
    abs_parent_scope = ""

  with tf.name_scope(abs_parent_scope):
    # pylint: disable=protected-access
    self._impl = self._spec._create_impl(
        name=self._name, trainable=self._trainable, tags=self._tags)
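A hedged usage sketch of constructing and running a Module, as a client would write it; the module path, its default signature, and the chosen scope name are assumptions for illustration:

```python
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

with tf.Graph().as_default():
  # Variables are created under the "text_embed" variable scope; because
  # trainable=True they are also added to TRAINABLE_VARIABLES (and any
  # regularization losses to REGULARIZATION_LOSSES).
  m = hub.Module("/tmp/text-embedding-module", trainable=True,
                 name="text_embed")
  embeddings = m(["hello world", "good morning"])
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.tables_initializer())  # table-initializer ops added by the module (if any)
    print(sess.run(embeddings))
```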
def _build_colocation_attr_map(input_map, absolute_import_scope):
  """Returns a dict mapping from pre-import to post-import colocation attrs.

  Args:
    input_map: as for fix_colocation_after_import.
    absolute_import_scope: as for fix_colocation_after_import.

  Returns:
    A dict that maps bytes `"loc:@" + absolute_import_scope + "/foo"`
    to _ConsistentValues set to the lists of bytes `["loc:@...", ...]`
    according to the rewriting scheme of fix_colocation_after_import.
    In case of an inconsistent rewriting, _ConsistentValue.has_error is true.
  """
  colocation_attr_map = collections.defaultdict(_ConsistentValue)
  used_outputs_of_imported_ops = collections.defaultdict(set)
  # Collect mappings from the input_map.
  for imported_tensor_name, mapped_tensor in input_map.items():
    imported_tensor_name = absolute_import_scope + "/" + imported_tensor_name
    imported_op_name, imported_index = _split_tensor_name(imported_tensor_name)
    key = tf.compat.as_bytes("loc:@" + imported_op_name)
    colocation_attr_map[key].Set(
        mapped_tensor.op.colocation_groups(),
        {"reason": "input '%s' is substituted by '%s'" % (
            imported_tensor_name, mapped_tensor.name)})
    used_outputs_of_imported_ops[imported_op_name].add(imported_index)
  # Add unchanged mappings for additional, non-remapped outputs of ops touched
  # by the input_map. For now, these just signal inconsistency when used.
  for imported_op_name, used_outputs in used_outputs_of_imported_ops.items():
    imported_op = tf_v1.get_default_graph().get_operation_by_name(
        imported_op_name)
    unused_outputs = set(range(len(imported_op.outputs))) - used_outputs
    if not unused_outputs:
      continue
    key = tf.compat.as_bytes("loc:@" + imported_op_name)
    if imported_op.colocation_groups() != [key]:
      # This should never happen: state nodes are remapped fully, input nodes
      # are prevented from having colocation attributes.
      raise ValueError(
          "Internal error: tensors from op '%s' are partially remapped in "
          "import but op.colocation_groups=%s cannot be captured in a "
          "simple rewrite rule." %
          (imported_op_name, imported_op.colocation_groups()))
    colocation_attr_map[key].Set(
        [key],
        {"reason": "tensor '%s:%s' is not substituted by inputs" % (
            imported_op_name,
            ",".join(str(i) for i in sorted(unused_outputs)))})

  return colocation_attr_map
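The `loc:@...` byte strings collected here are TensorFlow's colocation groups. A small self-contained illustration (unrelated to any module import, names chosen for the example) of what `colocation_groups()` returns for ops with and without an explicit constraint:

```python
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  x = tf.placeholder(tf.float32, name="x")
  with tf.colocate_with(x.op):
    y = tf.identity(x, name="y")
  print(x.op.colocation_groups())  # [b'loc:@x']  (implicit self-colocation)
  print(y.op.colocation_groups())  # [b'loc:@x']  (explicit "_class" constraint)
```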
def export(self, path, session):
  """Exports the module with the variables from the session in `path`.

  Note that it is the module definition in the ModuleSpec used to create this
  module that gets exported. The session is only used to provide the value
  of variables.

  Args:
    path: path where to export the module to.
    session: session where to export the variables from.

  Raises:
    RuntimeError: if there is an issue during the export.
  """
  if self._graph is not tf_v1.get_default_graph():
    raise RuntimeError("default graph differs from the graph where the "
                       "module was instantiated.")
  if self._graph is not session.graph:
    raise RuntimeError("session graph differs from the graph where the "
                       "module was instantiated.")
  self._impl.export(path, session)
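A hedged fine-tune-then-export sketch; the paths are hypothetical and the training loop is elided:

```python
import tensorflow.compat.v1 as tf
import tensorflow_hub as hub

with tf.Graph().as_default():
  m = hub.Module("/tmp/text-embedding-module", trainable=True)
  # ... build and run a training loop on top of m(...) here ...
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Exports the original module definition, with variable values read from
    # `sess` (which must share the graph the module was instantiated in).
    m.export("/tmp/text-embedding-module-tuned", sess)
```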
def create_apply_graph(self, signature, input_tensors, name):
  """See `ModuleImpl.create_apply_graph`."""
  signature_def = self._meta_graph.signature_def.get(signature)
  meta_graph = meta_graph_pb2.MetaGraphDef()
  meta_graph.CopyFrom(self._meta_graph)
  apply_graph = tf_v1.get_default_graph()
  infeed_map = tensor_info.build_input_map(signature_def.inputs,
                                           input_tensors)

  # Build an input map to feed when importing the apply-graph by augmenting
  # the state_map with the input args. This allows an input to override a
  # tensor from the state-graph.
  feed_map = dict(self._state_map)
  # If we are applying the module in a function with a TPUReplicateContext, we
  # must capture the state tensors in generating our feed_map and prune out
  # assign ops. Function graph semantics are different in that all ops are
  # executed regardless of dependency.
  # TODO(b/112575006): The following adds functionality of function call
  # within a TPU context. Work to generalize this for all function calls is
  # ongoing.
  if _is_tpu_graph_function():
    for k, v in self._state_map.items():
      feed_map[k] = apply_graph.capture(v)
    meta_graph_lib.prune_unused_nodes(meta_graph, signature_def)
    # After we prune the metagraph def, we might need to prune away
    # infeeds which no longer exist.
    meta_graph_lib.prune_feed_map(meta_graph, infeed_map)
  elif apply_graph.building_function:
    # Log a warning if a user is using a hub module in a function graph.
    # This is only expected to work if the function graph is pruned and
    # not all nodes are executed.
    #
    # E.g. it could work with "tf.compat.v1.wrap_function", but it will not
    # work with defun, Dataset.map_fn, etc...
    logging.warning("Using `hub.Module` while building a function: %s. This "
                    "can lead to errors if the function is not pruned.",
                    apply_graph.name)

  # As state ops in the apply graph are unused, replace them with Placeholders
  # so that in a hierarchical instantiation, apply_graph state ops are
  # ignored.
  replace_apply_state(
      meta_graph,
      list_registered_stateful_ops_without_inputs(meta_graph.graph_def),
      feed_map)
  feed_map.update(infeed_map)

  # Make state tensors enter the current context. This way the Module can be
  # applied inside a control flow structure such as a while_loop.
  control_flow = apply_graph._get_control_flow_context()  # pylint: disable=protected-access
  if control_flow:
    for key, value in sorted(feed_map.items()):
      feed_map[key] = control_flow.AddValue(value)

  # Don't mark the name as used at this point - import_scoped_meta_graph will
  # start using it.
  absolute_scope_name = apply_graph.unique_name(name, mark_as_used=False)
  relative_scope_name = absolute_scope_name.split("/")[-1]

  import_collections = [
      # In most cases ASSET_FILEPATHS are only used for the TABLE_INITIALIZERS
      # ops, however one could create a graph that uses an asset at any other
      # time. So every time we bring in a tensor that holds an asset filename,
      # we must annotate it as such, so that later re-exports have that
      # semantic information and can handle it.
      tf_v1.GraphKeys.ASSET_FILEPATHS,
      tf_v1.GraphKeys.COND_CONTEXT,
      tf_v1.GraphKeys.WHILE_CONTEXT,
  ]
  if self._trainable:
    import_collections.extend([tf_v1.GraphKeys.UPDATE_OPS])

  meta_graph_lib.filter_collections(meta_graph, import_collections)
  meta_graph_lib.prefix_shared_name_attributes(meta_graph,
                                               absolute_scope_name)
  if len(meta_graph.collection_def) and _is_tpu_graph_function():
    raise NotImplementedError(
        "Applying modules with collections inside TPU functions is not "
        "supported. Collections found: %s" % str(meta_graph.collection_def))

  tf_v1.train.import_meta_graph(
      meta_graph,
      input_map=feed_map,
      import_scope=relative_scope_name)
  fix_colocation_after_import(input_map=feed_map,
                              absolute_import_scope=absolute_scope_name)

  def get_tensor(name):
    # When trying to output an input tensor there are no nodes created within
    # the apply scope. So one must look into the input map.
    try:
      return feed_map[name]
    except KeyError:
      return apply_graph.get_tensor_by_name(
          meta_graph_lib.prepend_name_scope(
              name, import_scope=absolute_scope_name))

  return tensor_info.build_output_map(signature_def.outputs, get_tensor)
def _get_tensor(tensor_name):
  return tf_v1.get_default_graph().get_tensor_by_name(
      meta_graph_lib.prepend_name_scope(
          tensor_name, import_scope=absolute_scope_name))
def _create_state_graph(self, name):
  """Creates the graph nodes that hold the state of the Module.

  Args:
    name: name scope to create the state graph in.

  Returns:
    A tuple consisting of:
      variables_tensor_map: a map from tensor names in the original graph def
        to the created Variables objects.
      state_map: a map from tensor names in the original graph def to the
        instantiated tensors to be used as a state_map.
  """
  import_collections = [
      tf_v1.GraphKeys.GLOBAL_VARIABLES,
      tf_v1.GraphKeys.MODEL_VARIABLES,
      tf_v1.GraphKeys.TABLE_INITIALIZERS,
      tf_v1.GraphKeys.ASSET_FILEPATHS,  # Typically used to initialize tables.
      tf_v1.GraphKeys.COND_CONTEXT,
      tf_v1.GraphKeys.WHILE_CONTEXT,
  ]
  if self._trainable:
    # TODO(b/64049014): Import UPDATE_OPS which do not depend on inputs.
    import_collections.extend([tf_v1.GraphKeys.TRAINABLE_VARIABLES,
                               tf_v1.GraphKeys.REGULARIZATION_LOSSES])

  absolute_scope_name = tf_v1.get_default_graph().unique_name(
      name, mark_as_used=False)
  relative_scope_name = absolute_scope_name.split("/")[-1]
  assert relative_scope_name == name  # verify name scope was indeed unused.

  meta_graph = meta_graph_pb2.MetaGraphDef()
  meta_graph.CopyFrom(self._meta_graph)

  meta_graph_lib.filter_collections(meta_graph, import_collections)
  meta_graph_lib.prefix_shared_name_attributes(meta_graph,
                                               absolute_scope_name)

  tf_v1.train.import_meta_graph(
      meta_graph,
      input_map={},
      import_scope=relative_scope_name)

  # Build a map from the variable names in the module definition to the
  # actual instantiated variables.
  variables_tensor_map = {}
  for var in tf_v1.global_variables():
    if var.op.name.startswith(absolute_scope_name + "/"):
      variables_tensor_map[var.name[len(absolute_scope_name)+1:]] = var

  # Build a map of tensors to feed from the state-graph into subsequent
  # apply-graphs.
  def _get_tensor(tensor_name):
    return tf_v1.get_default_graph().get_tensor_by_name(
        meta_graph_lib.prepend_name_scope(
            tensor_name, import_scope=absolute_scope_name))

  state_op_names = list_registered_stateful_ops_without_inputs(
      meta_graph.graph_def)
  state_map = get_state_map(meta_graph, state_op_names, set(), _get_tensor)

  return variables_tensor_map, state_map
def _is_tpu_graph_function():
  graph = tf_v1.get_default_graph()
  return (graph.building_function and
          type(graph._get_control_flow_context()).__name__.endswith(  # pylint: disable=protected-access
              "TPUReplicateContext"))
def _apply_colocation_attr_map(colocation_attr_map, absolute_import_scope):
  """Rewrites colocation constraints in the current default graph.

  Nodes in `absolute_import_scope` get their "_class" attr lists rewritten
  according to `colocation_attr_map`: each entry that matches a key gets
  replaced by the associated values (with deduplication). The node's device
  is updated accordingly.

  Args:
    colocation_attr_map: as returned by _build_colocation_attr_map.
    absolute_import_scope: as for fix_colocation_after_import.

  Raises:
    ValueError: if rewriting runs into an inconsistent value in
      `colocation_attr_map`.
  """
  graph = tf_v1.get_default_graph()
  for op in graph.get_operations():
    # Rewrite the values of the "_class" attr that store colocation
    # constraints.
    # NOTE: The colocation_group loc:@X of a node with itself is not stored
    # explicitly as an attr, so rewrite errors for loc:@X are not triggered
    # by the mere existence of X.
    if not op.name.startswith(absolute_import_scope + "/"):
      continue
    try:
      class_values = op.get_attr("_class")
    except ValueError:
      continue  # No _class attr found; nothing to do.
    new_attr_value = tf_v1.AttrValue()
    new_coloc_groups = []
    for class_value in class_values:
      if class_value.startswith(tf.compat.as_bytes("loc:@")):
        if class_value not in colocation_attr_map:
          rewritten_class_value = [class_value]
        else:
          rewritten_class_value = (colocation_attr_map[
              class_value].GetConsistentValueOrRaise(
                  "Failed to rewrite colocation constraints while applying "
                  "hub.Module:\n"
                  "The module graph contains a node {op!r} "
                  "that has a colocation constraint {class_value!r} "
                  "with ambiguous rewriting {old_value!r} vs {new_value!r} "
                  "because {old_reason} and {new_reason}, respectively.\n"
                  "To fix, avoid publishing a module with inputs comprising "
                  "multiple outputs of one op that is referenced in "
                  "tf.colocate_with(...) constraints on other ops.",
                  {"op": op.name, "class_value": class_value}))
        new_coloc_groups.extend(rewritten_class_value)
      else:
        new_attr_value.list.s.append(class_value)
    new_coloc_groups = sorted(set(new_coloc_groups))
    new_attr_value.list.s.extend(new_coloc_groups)
    op._set_attr("_class", new_attr_value)  # pylint: disable=protected-access

    # Mimic the code of tf.import_graph_def(): If there are colocation
    # constraints, use any of them to set the device (overriding what the
    # device function stack would do), without attempting to merge or check
    # for equality. If they were inconsistent, TensorFlow's C++ runtime would
    # fail anyway due to conflicting colocation constraints.
    # Note that Hub imports GraphDefs with devices cleared, so this code deals
    # with the result of import_graph_def, not a setting saved in the module.
    if new_coloc_groups:
      new_coloc_device = ""
      for new_coloc_group in new_coloc_groups:
        assert new_coloc_group.startswith(tf.compat.as_bytes("loc:@"))
        new_coloc_target_op = graph.get_operation_by_name(
            tf.compat.as_str_any(new_coloc_group[5:]))
        new_coloc_device = new_coloc_target_op.device
        if new_coloc_device:
          break
      # Set this, even if empty, to avoid retaining an outdated value.
      op._set_device(new_coloc_device)  # pylint: disable=protected-access
@contextlib.contextmanager
def eval_function_for_module(spec, tags=None):
  """Context manager that yields a function to directly evaluate a Module.

  This creates a separate graph, in which all of the signatures of the module
  are instantiated. Then, it creates a session and initializes the module
  variables. Finally, it returns a function which can be used to evaluate the
  module signatures.

  The function returned by eval_function_for_module has the same syntax as
  Module.__call__, except that inputs and outputs are not tensors but actual
  values as used with Session.run().

  ```python
  with hub.eval_function_for_module("/tmp/text-embedding") as f:
    # The module can be directly evaluated using f without constructing a
    # graph.
    embeddings = f(["Hello world!",], signature="mysignature")
  ```

  Args:
    spec: A ModuleSpec defining the Module to instantiate or a path where to
      load a ModuleSpec from via `load_module_spec`.
    tags: A set of strings specifying the graph variant to use.

  Yields:
    A function whose keyword arguments are fed into the tfhub module and which
    returns a dictionary with the value of the output tensors.

  Raises:
    RuntimeError: explaining the reason why it failed to instantiate the
      Module.
    ValueError: if the requested graph variant does not exist.
  """
  # We create a separate graph and add all the signatures of the module to it.
  original_graph = tf_v1.get_default_graph()
  with tf.Graph().as_default():
    module = Module(spec, tags=tags)
    input_tensors_per_signature = {}
    output_tensors_per_signature = {}
    for signature in module.get_signature_names():
      # We scope with the signature name as different signatures will likely
      # contain tensors with the same name (e.g. the input and output tensors).
      with tf_v1.variable_scope(signature):
        input_tensors = {}
        for name, tensorinfo in module.get_input_info_dict(signature).items():
          # We need to be careful with the shape as it may be fully-known,
          # partially-known or even unknown.
          shape = tensorinfo.get_shape()
          effective_shape = None if shape.dims is None else shape.as_list()
          if tensorinfo.is_sparse:
            input_tensors[name] = tf_v1.sparse_placeholder(
                tensorinfo.dtype, shape=effective_shape, name=name)
          else:
            input_tensors[name] = tf_v1.placeholder(
                tensorinfo.dtype, shape=effective_shape, name=name)
        input_tensors_per_signature[signature] = input_tensors
        output_tensors_per_signature[signature] = module(
            input_tensors_per_signature[signature],
            signature=signature,
            as_dict=True)

    # Evaluating the tfhub module requires an active tensorflow session.
    with tf_v1.train.SingularMonitoredSession() as sess:

      def func(
          inputs=None,
          _sentinel=None,  # pylint: disable=invalid-name
          signature=None,
          as_dict=None):
        """Function that directly evaluates a signature in the module."""
        signature = signature or "default"
        input_tensors = input_tensors_per_signature[signature]

        dict_inputs = _prepare_dict_inputs(inputs, input_tensors)

        # The input arguments are directly fed into the session.
        feed_dict = {
            input_tensors[key]: value for key, value in dict_inputs.items()
        }
        output = output_tensors_per_signature[signature]
        output = _prepare_outputs(output, as_dict)
        return sess.run(output, feed_dict=feed_dict)

      with original_graph.as_default():
        # Yield the function since that will keep the session alive until the
        # user exits the context.
        yield func
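Complementing the docstring example above, a hedged sketch of the keyword arguments the yielded function accepts; the module path and signature name are hypothetical:

```python
import tensorflow_hub as hub

with hub.eval_function_for_module("/tmp/text-embedding-module") as f:
  # Default signature, single input, default output (a numpy array).
  vectors = f(["Hello world!"])
  # Named signature, dict input, all outputs returned as a dict.
  outputs = f({"text": ["Hello world!"]}, signature="encode", as_dict=True)
```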
def __call__(self, inputs=None,  # pylint: disable=invalid-name
             _sentinel=None,
             signature=None,
             as_dict=None):
  """Instantiates a module signature in the graph.

  Example calls:

  ```python
    # Use default signature with one input and default output.
    embeddings = m(["hello world", "good morning"])

    # Use "encode" signature with one input and default output.
    encodings = m(["hello world"], signature="encode")

    # Use default signature with input dict and output dict.
    dict_outputs = m({"text": [...], "lang": [...]}, as_dict=True)
  ```

  The method __call__() allows creating the graph ops that compute a
  signature's outputs given its inputs, using this module instance's state.
  Each signature can be applied multiple times with different inputs, and all
  applications share the same module state.

  A Module may define multiple signatures. Use `signature=<name>` to identify
  the specific signature to instantiate. If omitted or None, the default
  signature is used.

  A signature may define various outputs. Use `as_dict=True` to return a dict
  of all outputs. If omitted or False, the output named 'default' is
  returned.

  During this call a Module will:

  - Add ops in the current name scope to convert the inputs into tensors to
    feed to the signature.

  - Add ops to the UPDATE_OPS collection which depend on at least one of the
    provided inputs if the Module was constructed with `trainable=True`.

  - Add constant tensors to ASSET_FILEPATHS, even if those are not directly
    needed for the signature.

  Note: the `hub.Module` implementation depends on graph pruning that usually
  happens during `session.run`, so it can lead to errors when used inside
  function graphs that execute all of their ops (e.g.
  `tf.data.Dataset.map`).

  Args:
    inputs: Inputs to the signature. A dict from input names to tensor
      values. If the signature only expects one input, one may pass a single
      value. If the signature has no inputs, it may be omitted.
    _sentinel: Used to prevent positional parameters besides `inputs`.
    signature: A string with the signature name to apply. If None, the
      default signature is used.
    as_dict: A boolean indicating whether to return all the outputs of the
      signature as a dict or return only the default output.

  Returns:
    A tensor (single or sparse) if the signature defines a default output, or
    a dict from strings (output names) to tensors if `as_dict=True` is used.

  Raises:
    TypeError: If there is a mismatch on arguments, inputs or outputs of the
      module signature.
    RuntimeError: If there are errors during creation of the signature graph.
  """
  if self._graph is not tf_v1.get_default_graph():
    raise RuntimeError(
        "Module must be applied in the graph it was instantiated for.")

  signature = self._impl.get_signature_name(signature)
  # SavedModel non-default signatures automatically include ':' in them,
  # but that is an invalid character for a name that is used as part
  # of variable scopes.
  safe_signature = signature.replace(":", "_")
  name = "%s_apply_%s" % (self._name, safe_signature)

  dict_inputs = _convert_dict_inputs(
      inputs, self._spec.get_input_info_dict(signature=signature,
                                             tags=self._tags))

  dict_outputs = self._impl.create_apply_graph(
      signature=signature,
      input_tensors=dict_inputs,
      name=name)
  return _prepare_outputs(dict_outputs, as_dict=as_dict)
def _get_tensor(name):
  return tf_v1.get_default_graph().get_tensor_by_name(name)