def testAssets(self):
  """Exports a graph with one asset file and checks the rewritten path."""
  src_asset = os.path.join(self.get_temp_dir(), "hello.txt")
  _write_string_to_file(src_asset, "hello world")

  with tf.Graph().as_default() as graph:
    file_tensor = tf.constant(src_asset, name="file")
    graph.add_to_collection(tf_v1.GraphKeys.ASSET_FILEPATHS, file_tensor)
    saved_model_lib.add_signature("default", {}, {"default": file_tensor})

  handler = saved_model_lib.SavedModelHandler()
  handler.add_graph_copy(graph)
  export_dir = os.path.join(self.get_temp_dir(), "exported")
  handler.export(export_dir)

  # The exporter must copy the asset into the "assets" subdirectory.
  copied_asset = os.path.join(export_dir, "assets", "hello.txt")
  self.assertTrue(tf_v1.gfile.Exists(copied_asset))

  # After reloading, the constant must resolve to the exported copy,
  # not to the original temp-dir path.
  restored = saved_model_lib.load(export_dir)
  with _instantiate_meta_graph(restored).as_default():
    with tf_v1.Session() as sess:
      self.assertEqual(sess.run("file:0"), tf.compat.as_bytes(copied_asset))
def testWithMultipleAssetsWithSameBasename(self):
  """Two assets sharing a basename must both survive export and reload."""
  asset_root = os.path.join(self.get_temp_dir(), "asset")
  path_a = os.path.join(asset_root, "a", "hello.txt")
  path_b = os.path.join(asset_root, "b", "hello.txt")
  for path, text in ((path_a, "hello A"), (path_b, "hello B")):
    tf_v1.gfile.MakeDirs(os.path.dirname(path))
    _write_string_to_file(path, text)

  with tf.Graph().as_default() as graph:
    tensor_a = tf.constant(path_a, name="file_a")
    tensor_b = tf.constant(path_b, name="file_b")
    for tensor in (tensor_a, tensor_b):
      graph.add_to_collection(tf_v1.GraphKeys.ASSET_FILEPATHS, tensor)
    saved_model_lib.add_signature("default", {}, {"default": tensor_a})

  export_dir = os.path.join(self.get_temp_dir(), "exported")
  handler = saved_model_lib.SavedModelHandler()
  handler.add_graph_copy(graph)
  handler.export(export_dir)
  # Delete the originals so the assertions below can only succeed by
  # reading the de-duplicated copies inside the export directory.
  tf_v1.gfile.DeleteRecursively(asset_root)

  restored = saved_model_lib.load(export_dir)
  with _instantiate_meta_graph(restored).as_default():
    with tf_v1.Session() as sess:
      self.assertEqual(_read_file_to_string(sess.run("file_a:0")), "hello A")
      self.assertEqual(_read_file_to_string(sess.run("file_b:0")), "hello B")
def testCreationOfAssetsKeyCollectionIsDeterministic(self):
  """_make_assets_key_collection must emit assets in sorted tensor order."""
  asset_dir = os.path.join(self.get_temp_dir(), "assets")
  tf_v1.gfile.MakeDirs(asset_dir)
  paths = [os.path.join(asset_dir, "file%d.txt" % i) for i in range(10)]
  for path in paths:
    _write_string_to_file(path, "I am file %s" % path)

  with tf.Graph().as_default() as graph:
    constants = [tf.constant(p, name=os.path.basename(p)) for p in paths]
    for constant in constants:
      graph.add_to_collection(tf_v1.GraphKeys.ASSET_FILEPATHS, constant)
    saved_model_lib.add_signature("default", {}, {"default": constants[0]})

  handler = saved_model_lib.SavedModelHandler()
  handler.add_graph_copy(graph)
  # Work on a deep copy so the handler's own proto stays untouched.
  proto_copy = copy.deepcopy(handler._proto)
  export_dir = os.path.join(self.get_temp_dir(), "assets_key_test")
  saved_model_lib._make_assets_key_collection(proto_copy, export_dir)

  meta_graph = list(proto_copy.meta_graphs)[0]
  packed_assets = meta_graph.collection_def[
      tf_v1.saved_model.constants.ASSETS_KEY].any_list.value
  tensor_names = []
  for packed in packed_assets:
    asset_def = meta_graph_pb2.AssetFileDef()
    packed.Unpack(asset_def)
    tensor_names.append(asset_def.tensor_info.name)
  self.assertEqual(tensor_names, sorted(tensor_names))
def testSignatureImplementationIsInvisible(self):
  """add_signature must not leave visible collections in the user graph."""
  with tf.Graph().as_default() as graph:
    saved_model_lib.add_signature("test", {}, {})
    # The signature bookkeeping must not leak into graph collections.
    self.assertEqual(graph.get_all_collection_keys(), [])

  handler = saved_model_lib.SavedModelHandler()
  handler.add_graph_copy(graph)
  # Exactly one meta graph, holding the signature but no collections.
  meta_graph, = handler.meta_graphs
  self.assertEqual(len(meta_graph.collection_def), 0)
  self.assertEqual(len(meta_graph.signature_def), 1)
def testEmptyCollectionsDoNotShowUpInMetaGraphDef(self):
  """Collections emptied before export must be absent from the proto."""
  with tf.Graph().as_default() as graph:
    tf.Variable("name")
    # Creating the variable populates exactly two collections.
    self.assertEqual(len(graph.get_all_collection_keys()), 2)
    # Clear each collection's contents in place; the keys remain
    # registered on the graph but now map to empty lists.
    for key in graph.get_all_collection_keys():
      del graph.get_collection_ref(key)[:]

  handler = saved_model_lib.SavedModelHandler()
  handler.add_graph_copy(graph)
  meta_graph, = handler.meta_graphs
  self.assertEqual(len(meta_graph.collection_def), 0)
def testTags(self):
  """Graph variants are retrievable only by their exact tag sets."""
  with tf.Graph().as_default() as graph:
    saved_model_lib.add_signature("default", {}, {"default": tf.constant(1)})

  handler = saved_model_lib.SavedModelHandler()
  handler.add_graph_copy(graph, ["tag1"])
  handler.add_graph_copy(graph, ["tag1", "tag2"])

  expected_tag_sets = [set(["tag1"]), set(["tag1", "tag2"])]
  self.assertAllEqual(sorted(handler.get_tags()), sorted(expected_tag_sets))
  # Lookup matches on the full tag set, regardless of element order.
  self.assertIsNotNone(handler.get_meta_graph_copy(["tag1"]))
  self.assertIsNotNone(handler.get_meta_graph_copy(["tag2", "tag1"]))
  # A partial tag set is not a match.
  with self.assertRaises(KeyError):
    handler.get_meta_graph_copy(["tag2"])
def testTagsOnEmptyGraph(self):
  """Tag bookkeeping works for a graph with no ops or signatures.

  NOTE(review): this method was previously also named ``testTags``, the
  same name as the earlier test above; if both live in the same class,
  the later definition silently shadows the earlier one so only one of
  the two tests is ever discovered and run. Renamed so both execute.
  """
  graph = tf.Graph()
  handler = saved_model_lib.SavedModelHandler()
  handler.add_graph_copy(graph, ["tag1"])
  handler.add_graph_copy(graph, ["tag1", "tag2"])
  self.assertAllEqual(sorted(handler.get_tags()),
                      sorted([set(["tag1"]), set(["tag1", "tag2"])]))
  # assertIsNotNone gives a clearer failure message than
  # assertTrue(x is not None).
  self.assertIsNotNone(handler.get_meta_graph_copy(["tag1"]))
  self.assertIsNotNone(handler.get_meta_graph_copy(["tag2", "tag1"]))
  with self.assertRaises(KeyError):
    handler.get_meta_graph_copy(["tag2"])
def testSignatures(self):
  """Multiple signatures keep their declared input/output maps."""
  with tf.Graph().as_default() as graph:
    factor_a = tf.constant(2)
    factor_b = tf.constant(3)
    product = factor_a * factor_b
    saved_model_lib.add_signature("six", {}, {"out": product})
    saved_model_lib.add_signature("mul2", {"in": factor_b}, {"out": product})

  handler = saved_model_lib.SavedModelHandler()
  handler.add_graph_copy(graph)

  signatures = handler.get_meta_graph_copy().signature_def
  self.assertEqual(set(signatures.keys()), set(["six", "mul2"]))
  # "six" takes no inputs; "mul2" exposes one input named "in".
  self.assertAllEqual(list(signatures["six"].inputs.keys()), [])
  self.assertAllEqual(list(signatures["six"].outputs.keys()), ["out"])
  self.assertAllEqual(list(signatures["mul2"].inputs.keys()), ["in"])
  self.assertAllEqual(list(signatures["mul2"].outputs.keys()), ["out"])
def create_module_spec(module_fn, tags_and_args=None, drop_collections=None):
  """Creates a ModuleSpec from a function that builds the module's graph.

  The `module_fn` is called on a new graph (not the current one) to build the
  graph of the module and define its signatures via `hub.add_signature()`.
  Example:

  ```python
  # Define a text embedding module.
  def my_text_module_fn():
    text_input = tf.placeholder(dtype=tf.string, shape=[None])
    embeddings = compute_embedding(text_input)
    hub.add_signature(inputs=text_input, outputs=embeddings)
  ```

  See `add_signature()` for documentation on adding multiple input/output
  signatures.

  NOTE: In anticipation of future TF-versions, `module_fn` is called on a graph
  that uses resource variables by default. If you want old-style variables then
  you can use `with tf.variable_scope("", use_resource=False)` in `module_fn`.

  Multiple graph variants can be defined by using the `tags_and_args` argument.
  For example, the code:

  ```python
  hub.create_module_spec(
      module_fn,
      tags_and_args=[({"train"}, {"is_training":True}),
                     (set(), {"is_training":False})])
  ```

  calls `module_fn` twice, once as `module_fn(is_training=True)` and once as
  `module_fn(is_training=False)` to define the respective graph variants:
  for training with tags {"train"} and for inference with the empty set of
  tags. Using the empty set aligns the inference case with the default in
  Module.__init__().

  Args:
    module_fn: a function to build a graph for the Module.
    tags_and_args: Optional list of tuples (tags, kwargs) of tags and keyword
      args used to define graph variants. If omitted, it is interpreted as
      [set(), {}], meaning `module_fn` is called once with no args.
    drop_collections: list of collection to drop.

  Returns:
    A ModuleSpec.

  Raises:
    ValueError: if it fails to construct the ModuleSpec due to bad or
      unsupported values in the arguments or in the graphs constructed by
      `module_fn`.
  """
  if not drop_collections:
    drop_collections = []

  # With no explicit variants, build exactly one untagged graph and skip
  # reporting tags in colocation error messages (there is nothing to report).
  report_tags = True
  if not tags_and_args:
    tags_and_args = [(set(), {})]
    report_tags = False

  saved_model_handler = saved_model_lib.SavedModelHandler()
  for tags, args in tags_and_args:
    with tf.Graph().as_default() as graph:
      # Resource variables by default, in anticipation of future TF versions;
      # tf_v1 is used (instead of the bare `tf` namespace, which lacks
      # `variable_scope` under TF2) for consistency with the rest of this file.
      with tf_v1.variable_scope("", use_resource=True):
        module_fn(**args)

      for collection_key in drop_collections:
        # Clear the collection contents in place on the graph being built.
        del tf_v1.get_collection_ref(collection_key)[:]

    err = find_state_op_colocation_error(graph, tags if report_tags else None)
    if err:
      raise ValueError(err)
    saved_model_handler.add_graph_copy(graph, tags=tags)

  return _ModuleSpec(saved_model_handler, checkpoint_variables_path=None)