Example #1
# Imports assumed by the examples below (tensorflow_hub binds tf_v1 to
# tf.compat.v1 internally):
import tensorflow as tf
import tensorflow.compat.v1 as tf_v1


def _try_get_state_scope(name, mark_name_scope_used=True):
    """Returns a fresh variable/name scope for a module's state.

    In order to import a module into a given scope without major complications
    we require the scope to be empty. This function picks an unused scope in
    which to define the module state. This is non-trivial in cases where
    name_scopes and variable_scopes are out of sync, e.g. on TPUs or when
    re-entering scopes.

    Args:
      name: A string with the name of the module as supplied by the client.
      mark_name_scope_used: a boolean, indicating whether to mark the name
        scope of the returned value as used.

    Returns:
      A string with the name of the chosen scope, ending in "/".

    Raises:
      RuntimeError: if the name scope of the freshly created variable scope is
        already used.
    """
    tmp_scope_name = tf_v1.get_variable_scope().name
    if tmp_scope_name:
        tmp_scope_name += "/"
    with tf.name_scope(tmp_scope_name):
        # Pick an unused variable scope.
        with tf_v1.variable_scope(None,
                                  default_name=name,
                                  auxiliary_name_scope=False) as vs:
            abs_state_scope = vs.name + "/"
        # Verify that the name scope is available and mark it used if requested.
        graph = tf_v1.get_default_graph()
        unique_name_scope = graph.unique_name(name, mark_name_scope_used) + "/"
        if unique_name_scope != abs_state_scope:
            raise RuntimeError(
                "variable_scope %s was unused but the corresponding "
                "name_scope was already taken." % abs_state_scope)
    return abs_state_scope
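
A minimal usage sketch (hypothetical name; the expected values are inferred from the tests in Examples #3 and #7 below): repeated calls with the same module name pick fresh, non-clashing scopes.

```python
with tf.Graph().as_default():
    first = _try_get_state_scope("my_module")   # hypothetical module name
    second = _try_get_state_scope("my_module")  # same name, scope is uniquified
    print(first, second)  # expected: my_module/ my_module_1/
```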
Example #2
def testModuleInNestedScope(self):
    # Test method from a tf.test.TestCase subclass: a Module built inside
    # a variable scope still evaluates correctly (_ModuleSpec is the test's
    # fake doubling module).
    with tf.Graph().as_default():
        with tf_v1.variable_scope("foo"):
            m = module.Module(_ModuleSpec())
            result = m([1, 2])
        with tf_v1.Session() as session:
            self.assertAllEqual(session.run(result), [2, 4])
Example #3
def testGetStateScopeWithActiveScopes(self):
    with tf.Graph().as_default():
        with tf_v1.name_scope("foo"):
            abs_scope = module._try_get_state_scope("a", False)
            self.assertEqual(abs_scope, "a/")
            with tf_v1.name_scope(abs_scope) as ns:
                self.assertEqual(ns, "a/")

    with tf.Graph().as_default():
        with tf_v1.variable_scope("vs"):
            self.assertEqual(module._try_get_state_scope("a", False),
                             "vs/a/")
            with tf_v1.name_scope(name="a") as ns:
                self.assertEqual(ns, "vs/a/")

    with tf.Graph().as_default():
        with tf_v1.name_scope("foo"):
            with tf_v1.variable_scope("vs"):
                self.assertEqual(module._try_get_state_scope("a", False),
                                 "vs/a/")
Example #4
def create_module_spec(module_fn, tags_and_args=None, drop_collections=None):
  """Creates a ModuleSpec from a function that builds the module's graph.

  The `module_fn` is called on a new graph (not the current one) to build the
  graph of the module and define its signatures via `hub.add_signature()`.
  Example:

  ```python
  # Define a text embedding module.
  def my_text_module_fn():
    text_input = tf.placeholder(dtype=tf.string, shape=[None])
    embeddings = compute_embedding(text_input)
    hub.add_signature(inputs=text_input, outputs=embeddings)
  ```

  See `add_signature()` for documentation on adding multiple input/output
  signatures.

  NOTE: In anticipation of future TF-versions, `module_fn` is called on a graph
  that uses resource variables by default. If you want old-style variables then
  you can use `with tf.variable_scope("", use_resource=False)` in `module_fn`.

  Multiple graph variants can be defined by using the `tags_and_args` argument.
  For example, the code:

  ```python
  hub.create_module_spec(
      module_fn,
      tags_and_args=[({"train"}, {"is_training":True}),
                     (set(), {"is_training":False})])
  ```

  calls `module_fn` twice, once as `module_fn(is_training=True)` and once as
  `module_fn(is_training=False)` to define the respective graph variants:
  for training with tags {"train"} and for inference with the empty set of tags.
  Using the empty set aligns the inference case with the default in
  Module.__init__().

  Args:
    module_fn: a function to build a graph for the Module.
    tags_and_args: Optional list of tuples (tags, kwargs) of tags and keyword
      args used to define graph variants. If omitted, it is interpreted as
      [(set(), {})], meaning `module_fn` is called once with no args.
    drop_collections: list of collections to drop.

  Returns:
    A ModuleSpec.

  Raises:
    ValueError: if it fails to construct the ModuleSpec due to bad or
      unsupported values in the arguments or in the graphs constructed by
      `module_fn`.
  """
  if not drop_collections:
    drop_collections = []

  report_tags = True
  if not tags_and_args:
    tags_and_args = [(set(), {})]
    report_tags = False

  saved_model_handler = saved_model_lib.SavedModelHandler()
  for tags, args in tags_and_args:
    with tf.Graph().as_default() as graph:
      with tf_v1.variable_scope("", use_resource=True):
        module_fn(**args)

      for collection_key in drop_collections:
        del tf_v1.get_collection_ref(collection_key)[:]

    err = find_state_op_colocation_error(graph, tags if report_tags else None)
    if err: raise ValueError(err)
    saved_model_handler.add_graph_copy(graph, tags=tags)

  return _ModuleSpec(saved_model_handler, checkpoint_variables_path=None)
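
A hedged end-to-end sketch of the flow above, assuming TF1-style graph mode; `double_fn` and its doubling signature are invented for illustration:

```python
import tensorflow as tf
import tensorflow_hub as hub


def double_fn():
    # Hypothetical module_fn: one default signature that doubles its input.
    x = tf.compat.v1.placeholder(dtype=tf.float32, shape=[None])
    hub.add_signature(inputs=x, outputs=2.0 * x)


spec = hub.create_module_spec(double_fn)  # one variant: (set(), {})

with tf.Graph().as_default():
    outputs = hub.Module(spec)([1.0, 2.0])
    with tf.compat.v1.train.SingularMonitoredSession() as sess:
        print(sess.run(outputs))  # expected: [2. 4.]
```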
Example #5
import contextlib


@contextlib.contextmanager
def eval_function_for_module(spec, tags=None):
    """Context manager that yields a function to directly evaluate a Module.

    This creates a separate graph, in which all of the signatures of the module
    are instantiated. Then, it creates a session and initializes the module
    variables. Finally, it returns a function which can be used to evaluate the
    module signatures.

    The function returned by eval_function_for_module has the same syntax as
    Module.__call__, except that inputs and outputs are not tensors but actual
    values as used with Session.run().

    ```python
    with hub.eval_function_for_module("/tmp/text-embedding") as f:
      # The module can be directly evaluated using f without constructing a graph.
      embeddings = f(["Hello world!",], signature="mysignature")
    ```

    Args:
      spec: A ModuleSpec defining the Module to instantiate or a path where to
        load a ModuleSpec from via `load_module_spec`.
      tags: A set of strings specifying the graph variant to use.

    Yields:
      A function whose keyword arguments are fed into the tfhub module and which
        returns a dictionary with the value of the output tensors.

    Raises:
      RuntimeError: explaining the reason why it failed to instantiate the
        Module.
      ValueError: if the requested graph variant does not exist.
    """
    # We create a separate graph and add all the signatures of the module to it.
    original_graph = tf_v1.get_default_graph()
    with tf.Graph().as_default():
        module = Module(spec, tags=tags)
        input_tensors_per_signature = {}
        output_tensors_per_signature = {}
        for signature in module.get_signature_names():
            # We scope with the signature name as different signatures will likely
            # contain tensors with the same name (e.g. the input and output tensors).
            with tf_v1.variable_scope(signature):
                input_tensors = {}
                for name, tensorinfo in module.get_input_info_dict(
                        signature).items():
                    # We need to be careful with the shape as it may be
                    # fully-known, partially-known or even unknown.
                    shape = tensorinfo.get_shape()
                    effective_shape = (None if shape.dims is None
                                       else shape.as_list())
                    if tensorinfo.is_sparse:
                        input_tensors[name] = tf_v1.sparse_placeholder(
                            tensorinfo.dtype, shape=effective_shape, name=name)
                    else:
                        input_tensors[name] = tf_v1.placeholder(
                            tensorinfo.dtype, shape=effective_shape, name=name)
                input_tensors_per_signature[signature] = input_tensors
                output_tensors_per_signature[signature] = module(
                    input_tensors_per_signature[signature],
                    signature=signature,
                    as_dict=True)

        # Evaluating the tfhub module requires an active tensorflow session.
        with tf_v1.train.SingularMonitoredSession() as sess:

            def func(
                    inputs=None,
                    _sentinel=None,  # pylint: disable=invalid-name
                    signature=None,
                    as_dict=None):
                """Function that directly evaluates a signature in the module."""
                signature = signature or "default"
                input_tensors = input_tensors_per_signature[signature]

                dict_inputs = _prepare_dict_inputs(inputs, input_tensors)

                # The input arguments are directly fed into the session.
                feed_dict = {
                    input_tensors[key]: value
                    for key, value in dict_inputs.items()
                }
                output = output_tensors_per_signature[signature]
                output = _prepare_outputs(output, as_dict)
                return sess.run(output, feed_dict=feed_dict)

            with original_graph.as_default():
                # Yield the function since that will keep the session alive until the
                # user exits the context.
                yield func
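
Reusing the hypothetical doubling `spec` sketched after Example #4, the yielded function can be called like Module.__call__ but with concrete values instead of tensors:

```python
with eval_function_for_module(spec) as f:
    result = f([1.0, 2.0])  # feeds the placeholder and runs the session
    print(result)           # expected: [2. 4.]
```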
Example #6
def __init__(self, name, trainable):
    # Minimal test stub: entering (and immediately leaving) the variable
    # scope is enough to claim `name` in the graph; no state is built.
    super(_ModuleImpl, self).__init__()
    with tf_v1.variable_scope(name):
        pass
Example #7
def testGetStateScope_UsesVariableScope(self):
    with tf.Graph().as_default():
        self.assertEqual(module._try_get_state_scope("a"), "a/")
        # The scope "a" is now taken, so a default_name of "a" is uniquified.
        with tf_v1.variable_scope(None, default_name="a") as vs:
            self.assertEqual(vs.name, "a_1")