예제 #1
0
    def testAssets(self):
        """An asset is copied into the export and re-pointed there on load."""
        src_path = os.path.join(self.get_temp_dir(), "hello.txt")
        _write_string_to_file(src_path, "hello world")

        with tf.Graph().as_default() as graph:
            file_tensor = tf.constant(src_path, name="file")
            graph.add_to_collection(
                tf_v1.GraphKeys.ASSET_FILEPATHS, file_tensor)
            saved_model_lib.add_signature(
                "default", {}, {"default": file_tensor})

        export_dir = os.path.join(self.get_temp_dir(), "exported")
        handler = saved_model_lib.SavedModelHandler()
        handler.add_graph_copy(graph)
        handler.export(export_dir)

        # The asset file must now live under the export's "assets" directory.
        copied_path = os.path.join(export_dir, "assets", "hello.txt")
        self.assertTrue(tf_v1.gfile.Exists(copied_path))

        # After loading, the "file" constant must evaluate to the exported
        # copy's path rather than the original one.
        restored = saved_model_lib.load(export_dir)
        with _instantiate_meta_graph(restored).as_default():
            with tf_v1.Session() as sess:
                resolved = sess.run("file:0")
                self.assertEqual(resolved, tf.compat.as_bytes(copied_path))
예제 #2
0
    def testWithMultipleAssetsWithSameBasename(self):
        """Two assets sharing a basename both survive export and reload."""
        asset_root = os.path.join(self.get_temp_dir(), "asset")
        path_a = os.path.join(asset_root, "a", "hello.txt")
        path_b = os.path.join(asset_root, "b", "hello.txt")
        for path in (path_a, path_b):
            tf_v1.gfile.MakeDirs(os.path.dirname(path))
        _write_string_to_file(path_a, "hello A")
        _write_string_to_file(path_b, "hello B")
        with tf.Graph().as_default() as graph:
            const_a = tf.constant(path_a, name="file_a")
            const_b = tf.constant(path_b, name="file_b")
            for const in (const_a, const_b):
                graph.add_to_collection(
                    tf_v1.GraphKeys.ASSET_FILEPATHS, const)
            saved_model_lib.add_signature("default", {}, {"default": const_a})

        export_dir = os.path.join(self.get_temp_dir(), "exported")
        handler = saved_model_lib.SavedModelHandler()
        handler.add_graph_copy(graph)
        handler.export(export_dir)
        # Delete the originals so the reads below can only succeed if the
        # export kept its own copy of each file.
        tf_v1.gfile.DeleteRecursively(asset_root)

        restored = saved_model_lib.load(export_dir)
        with _instantiate_meta_graph(restored).as_default():
            with tf_v1.Session() as sess:
                for tensor_name, expected in (("file_a:0", "hello A"),
                                              ("file_b:0", "hello B")):
                    self.assertEqual(
                        _read_file_to_string(sess.run(tensor_name)), expected)
예제 #3
0
    def testCreationOfAssetsKeyCollectionIsDeterministic(self):
        """The ASSETS_KEY collection lists asset tensors in sorted order."""
        asset_dir = os.path.join(self.get_temp_dir(), "assets")
        tf_v1.gfile.MakeDirs(asset_dir)
        paths = [os.path.join(asset_dir, "file%d.txt" % i) for i in range(10)]
        for path in paths:
            _write_string_to_file(path, "I am file %s" % path)

        with tf.Graph().as_default() as graph:
            asset_consts = [
                tf.constant(p, name=os.path.basename(p)) for p in paths
            ]
            for const in asset_consts:
                graph.add_to_collection(
                    tf_v1.GraphKeys.ASSET_FILEPATHS, const)
            saved_model_lib.add_signature(
                "default", {}, {"default": asset_consts[0]})

        handler = saved_model_lib.SavedModelHandler()
        handler.add_graph_copy(graph)
        # Operate on a deep copy so the handler's own proto is not modified.
        proto_copy = copy.deepcopy(handler._proto)
        target_dir = os.path.join(self.get_temp_dir(), "assets_key_test")
        saved_model_lib._make_assets_key_collection(proto_copy, target_dir)

        first_meta_graph = proto_copy.meta_graphs[0]
        packed_assets = first_meta_graph.collection_def[
            tf_v1.saved_model.constants.ASSETS_KEY].any_list.value
        names = []
        for packed in packed_assets:
            asset_def = meta_graph_pb2.AssetFileDef()
            packed.Unpack(asset_def)
            names.append(asset_def.tensor_info.name)
        self.assertEqual(names, sorted(names))
예제 #4
0
    def testSignatureImplementationIsInvisible(self):
        """add_signature leaves no collections behind, only a SignatureDef."""
        with tf.Graph().as_default() as graph:
            saved_model_lib.add_signature("test", {}, {})
            # No collection keys may be visible in the user's graph.
            self.assertEqual(graph.get_all_collection_keys(), [])

        handler = saved_model_lib.SavedModelHandler()
        handler.add_graph_copy(graph)
        (only_meta_graph,) = handler.meta_graphs
        # The exported MetaGraphDef carries the signature but no collections.
        self.assertEqual(len(only_meta_graph.collection_def), 0)
        self.assertEqual(len(only_meta_graph.signature_def), 1)
예제 #5
0
    def testEmptyCollectionsDoNotShowUpInMetaGraphDef(self):
        """Collections emptied before export are omitted from the MetaGraphDef."""
        with tf.Graph().as_default() as graph:
            tf.Variable("name")
            # The variable is expected to populate exactly two collections.
            keys = graph.get_all_collection_keys()
            self.assertEqual(len(keys), 2)
            # Clear each collection's list in place.
            for key in keys:
                del graph.get_collection_ref(key)[:]

        handler = saved_model_lib.SavedModelHandler()
        handler.add_graph_copy(graph)
        (only_meta_graph,) = handler.meta_graphs
        self.assertEqual(len(only_meta_graph.collection_def), 0)
예제 #6
0
 def testTags(self):
   """Graph copies are retrievable by their exact tag set, in any order."""
   with tf.Graph().as_default() as graph:
     saved_model_lib.add_signature("default", {}, {"default": tf.constant(1)})
   handler = saved_model_lib.SavedModelHandler()
   handler.add_graph_copy(graph, ["tag1"])
   handler.add_graph_copy(graph, ["tag1", "tag2"])
   # Canonicalize each tag set to a sorted list before comparing. Sorting a
   # list of sets directly uses `<` (proper subset), a partial order that only
   # happens to work here because one tag set contains the other.
   actual_tags = sorted(sorted(tags) for tags in handler.get_tags())
   self.assertEqual(actual_tags, [["tag1"], ["tag1", "tag2"]])
   self.assertIsNotNone(handler.get_meta_graph_copy(["tag1"]))
   # The order of tags in the query does not matter.
   self.assertIsNotNone(handler.get_meta_graph_copy(["tag2", "tag1"]))
   # A tag set that was never added raises KeyError, even though it is a
   # subset of one that was.
   with self.assertRaises(KeyError):
     handler.get_meta_graph_copy(["tag2"])
예제 #7
0
 def testTags(self):
     """Graph copies are retrievable by their exact tag set, in any order."""
     graph = tf.Graph()
     handler = saved_model_lib.SavedModelHandler()
     handler.add_graph_copy(graph, ["tag1"])
     handler.add_graph_copy(graph, ["tag1", "tag2"])
     # Canonicalize each tag set to a sorted list before comparing. Sorting
     # sets directly uses `<` (proper subset), a partial order that gives no
     # well-defined ordering in general.
     actual_tags = sorted(sorted(tags) for tags in handler.get_tags())
     self.assertEqual(actual_tags, [["tag1"], ["tag1", "tag2"]])
     self.assertIsNotNone(handler.get_meta_graph_copy(["tag1"]))
     # The order of tags in the query does not matter.
     self.assertIsNotNone(handler.get_meta_graph_copy(["tag2", "tag1"]))
     # A tag set that was never added raises KeyError, even though it is a
     # subset of one that was.
     with self.assertRaises(KeyError):
         handler.get_meta_graph_copy(["tag2"])
예제 #8
0
  def testSignatures(self):
    """Each add_signature call produces a SignatureDef with the given keys."""
    with tf.Graph().as_default() as graph:
      two = tf.constant(2)
      three = tf.constant(3)
      product = two * three
      saved_model_lib.add_signature("six", {}, {"out": product})
      saved_model_lib.add_signature("mul2", {"in": three}, {"out": product})

    handler = saved_model_lib.SavedModelHandler()
    handler.add_graph_copy(graph)

    sig_defs = handler.get_meta_graph_copy().signature_def
    self.assertEqual(set(sig_defs.keys()), {"six", "mul2"})
    # Check input/output key names of each signature, in a fixed order.
    for sig_name, in_keys, out_keys in (("six", [], ["out"]),
                                        ("mul2", ["in"], ["out"])):
      self.assertAllEqual(list(sig_defs[sig_name].inputs.keys()), in_keys)
      self.assertAllEqual(list(sig_defs[sig_name].outputs.keys()), out_keys)
예제 #9
0
def create_module_spec(module_fn, tags_and_args=None, drop_collections=None):
    """Creates a ModuleSpec from a function that builds the module's graph.

    The `module_fn` is called on a new graph (not the current one) to build
    the graph of the module and define its signatures via
    `hub.add_signature()`. Example:

    ```python
    # Define a text embedding module.
    def my_text_module_fn():
      text_input = tf.placeholder(dtype=tf.string, shape=[None])
      embeddings = compute_embedding(text_input)
      hub.add_signature(inputs=text_input, outputs=embeddings)
    ```

    See `add_signature()` for documentation on adding multiple input/output
    signatures.

    NOTE: In anticipation of future TF-versions, `module_fn` is called on a
    graph that uses resource variables by default. If you want old-style
    variables then you can use
    `with tf.variable_scope("", use_resource=False)` in `module_fn`.

    Multiple graph variants can be defined by using the `tags_and_args`
    argument. For example, the code:

    ```python
    hub.create_module_spec(
        module_fn,
        tags_and_args=[({"train"}, {"is_training":True}),
                       (set(), {"is_training":False})])
    ```

    calls `module_fn` twice, once as `module_fn(is_training=True)` and once as
    `module_fn(is_training=False)` to define the respective graph variants:
    for training with tags {"train"} and for inference with the empty set of
    tags. Using the empty set aligns the inference case with the default in
    Module.__init__().

    Args:
      module_fn: a function to build a graph for the Module.
      tags_and_args: Optional list of tuples (tags, kwargs) of tags and keyword
        args used to define graph variants. If omitted (or empty), it is
        interpreted as [(set(), {})], meaning `module_fn` is called once with
        no args and the resulting graph is stored under the empty tag set.
      drop_collections: Optional list of graph collection keys. The contents
        of each listed collection are cleared from every graph after
        `module_fn` has built it.

    Returns:
      A ModuleSpec.

    Raises:
      ValueError: if it fails to construct the ModuleSpec due to bad or
        unsupported values in the arguments or in the graphs constructed by
        `module_fn`.
    """
    if not drop_collections:
        drop_collections = []

    # Tags are passed to the colocation check only when the caller supplied
    # them explicitly; the implicit default variant reports None instead.
    report_tags = True
    if not tags_and_args:
        tags_and_args = [(set(), {})]
        report_tags = False

    saved_model_handler = saved_model_lib.SavedModelHandler()
    for tags, args in tags_and_args:
        with tf.Graph().as_default() as graph:
            # Build with resource variables by default (see NOTE above).
            with tf.variable_scope("", use_resource=True):
                module_fn(**args)

            # Empty each requested collection in place before the graph is
            # copied into the handler.
            for collection_key in drop_collections:
                del tf.get_collection_ref(collection_key)[:]

        err = find_state_op_colocation_error(
            graph, tags if report_tags else None)
        if err:
            raise ValueError(err)
        saved_model_handler.add_graph_copy(graph, tags=tags)

    return _ModuleSpec(saved_model_handler, checkpoint_variables_path=None)