def testGetSavedModelTagSets(self):
    """Tag sets are read back in the order their MetaGraphs were added."""
    saved_model_dir = os.path.join(test.get_temp_dir(), "test_tags")
    builder = saved_model_builder.SavedModelBuilder(saved_model_dir)

    # Each spec is (variable value, tags, save_variables).
    # - First graph: a single predefined tag, saved together with its weights.
    # - Second graph: a single predefined tag, model only (weights not updated).
    # - Third graph: multiple custom tags, model only (weights not updated).
    meta_graph_specs = (
        (42, [tag_constants.TRAINING], True),
        (43, [tag_constants.SERVING], False),
        (44, ["foo", "bar"], False),
    )
    for value, tags, save_variables in meta_graph_specs:
        with self.test_session(graph=ops.Graph()) as sess:
            self._init_and_validate_variable(sess, "v", value)
            if save_variables:
                builder.add_meta_graph_and_variables(sess, tags)
            else:
                builder.add_meta_graph(tags)

    # Persist everything to disk before reading the tag sets back.
    builder.save()

    self.assertEqual(
        [["train"], ["serve"], ["foo", "bar"]],
        reader.get_saved_model_tag_sets(saved_model_dir),
    )
def _from_saved_model(saved_model_dir):
    """Load a TF 1.x SavedModel and return it as a frozen GraphDef with shapes.

    Only the first tag set of the SavedModel is used. The outputs of all of its
    SignatureDefs become the output nodes passed to ``freeze_graph``.

    Args:
        saved_model_dir: Path to the SavedModel directory.

    Returns:
        The GraphDef produced by ``graph.as_graph_def(add_shapes=True)`` after
        freezing and ``remove_training_nodes``.

    Raises:
        NotImplementedError: If the SavedModel has no tag sets, or its first
            tag set is empty.
    """
    import os
    import tempfile

    from tensorflow.python.tools import freeze_graph

    # must import here as tf.contrib is only available on TF 1.x
    from tensorflow.contrib.saved_model.python.saved_model import reader

    # Check the list itself first: indexing [0] on an empty list would raise an
    # IndexError instead of the intended, descriptive NotImplementedError.
    tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
    saved_model_tags = tag_sets[0] if tag_sets else None
    if not saved_model_tags:
        msg = "Unsupported SavedModel directory format: no tag_sets available"
        raise NotImplementedError(msg)

    # From 1.13.1 on, the 1.x API is reached through tf.compat.v1. The
    # conditional only touches tf.compat when the version is new enough.
    tf_v1 = (
        tf.compat.v1
        if _get_version(tf.__version__) >= _StrictVersion("1.13.1")
        else tf
    )

    # get model outputs
    output_node_names = []
    sess = tf_v1.Session()
    try:
        metagraph = tf.saved_model.loader.load(
            sess, saved_model_tags, saved_model_dir
        )
        for sd in metagraph.signature_def.values():
            output_node_names.extend(
                o.name.split(":")[0] for o in sd.outputs.values()
            )
    finally:
        # Release the session even if loading fails.
        sess.close()

    # get frozen graph
    tf_v1.reset_default_graph()
    # mkstemp creates the file atomically (mktemp only returns a name, which is
    # race-prone); the file is removed once its contents are parsed.
    fd, output_graph = tempfile.mkstemp()
    os.close(fd)
    try:
        freeze_graph.freeze_graph(
            input_graph=None,
            input_saver=None,
            input_binary=None,
            input_checkpoint=None,
            output_node_names=",".join(output_node_names),
            restore_op_name=None,
            filename_tensor_name=None,
            output_graph=output_graph,
            clear_devices=True,
            initializer_nodes="",
            variable_names_whitelist="",
            variable_names_blacklist="",
            input_meta_graph=None,
            input_saved_model_dir=saved_model_dir,
            saved_model_tags=",".join(saved_model_tags),
        )
        graph_def = tf_v1.GraphDef()
        with open(output_graph, "rb") as f:
            graph_def.ParseFromString(f.read())
    finally:
        os.remove(output_graph)

    graph_def = tf_v1.graph_util.remove_training_nodes(graph_def)
    with tf.Graph().as_default() as graph:
        tf.graph_util.import_graph_def(graph_def, name="")
    return graph.as_graph_def(add_shapes=True)
def _get_tag_set(self):
    """Return the tag set of saved model, multiple metagraphs are not supported.

    When the SavedModel at ``self._model_dir`` holds several MetaGraphs, only
    the first tag set reported by the reader is returned.

    Returns:
        The tags of the first MetaGraph (one entry of
        ``reader.get_saved_model_tag_sets``).

    Raises:
        ImportError: If the TF 1.x contrib ``saved_model.reader`` is unavailable.
        ValueError: If the SavedModel contains no tag sets at all.
    """
    try:
        from tensorflow.contrib.saved_model.python.saved_model import reader
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "InputConfiguration: Unable to import saved_model.reader which is "
            "required to get tag set from saved model.") from err
    tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
    if not tag_sets:
        # Without this check, tag_sets[0] would fail with an opaque IndexError.
        raise ValueError(
            "No tag sets found in saved model: {}".format(self._model_dir))
    return tag_sets[0]
def _show_tag_sets(saved_model_dir):
    """Print every tag-set found in a SavedModel directory.

    After a header line, each tag-set is printed on its own line. Tag-sets are
    visited in sorted order, and the tags within each one are sorted and joined
    with ', '.

    Args:
      saved_model_dir: Directory containing the SavedModel to inspect.
    """
    found_tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
    print('The given SavedModel contains the following tag-sets:')
    rendered = (', '.join(sorted(tags)) for tags in sorted(found_tag_sets))
    for line in rendered:
        print(line)
def __get_tag_set(self) -> list:
    """Return the first tag set of the SavedModel at ``self.model_path``.

    Multiple MetaGraphs are not supported: only the first tag set reported by
    the reader is returned. The return type is a list of tags, matching the
    elements of ``reader.get_saved_model_tag_sets``, not a single string.

    Returns:
        The tags of the first MetaGraph.

    Raises:
        ImportError: If the TF 1.x contrib ``saved_model.reader`` is unavailable.
        ValueError: If the SavedModel contains no tag sets at all.
    """
    try:
        from tensorflow.contrib.saved_model.python.saved_model import reader
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "InputConfiguration: Unable to import saved_model.reader which is "
            "required to get tag set from saved model.") from err
    # NOTE(review): assumes self.model_path is a pathlib path (it provides
    # .as_posix()) -- confirm against the class definition.
    model_dir = self.model_path.as_posix()
    tag_sets = reader.get_saved_model_tag_sets(model_dir)
    if not tag_sets:
        # Without this check, tag_sets[0] would fail with an opaque IndexError.
        raise ValueError("No tag sets found in saved model: {}".format(model_dir))
    return tag_sets[0]
def _get_tag_set(self):
    """Return the tag set of saved model, multiple metagraphs are not supported.

    When the SavedModel at ``self._model_dir`` holds several MetaGraphs, only
    the first tag set reported by the reader is returned.

    Returns:
        The tags of the first MetaGraph (one entry of
        ``reader.get_saved_model_tag_sets``).

    Raises:
        ImportError: If the TF 1.x contrib ``saved_model.reader`` is unavailable.
        ValueError: If the SavedModel contains no tag sets at all.
    """
    try:
        from tensorflow.contrib.saved_model.python.saved_model import reader
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "InputConfiguration: Unable to import saved_model.reader which is "
            "required to get tag set from saved model.") from err
    tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
    if not tag_sets:
        # Without this check, tag_sets[0] would fail with an opaque IndexError.
        raise ValueError(
            "No tag sets found in saved model: {}".format(self._model_dir))
    return tag_sets[0]
def _show_tag_sets(saved_model_dir):
    """Print every tag-set found in a SavedModel directory.

    After a header line, each tag-set is printed on its own line. Tag-sets are
    visited in sorted order, and the tags within each one are sorted and joined
    with ', '.

    Args:
      saved_model_dir: Directory containing the SavedModel to inspect.
    """
    found_tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
    print('The given SavedModel contains the following tag-sets:')
    rendered = (', '.join(sorted(tags)) for tags in sorted(found_tag_sets))
    for line in rendered:
        print(line)
def _show_all(saved_model_dir):
    """Prints tag-set, SignatureDef and Inputs/Outputs information in SavedModel.

    Prints all tag-set, SignatureDef and Inputs/Outputs information stored in
    SavedModel directory. Tag-sets and SignatureDef keys are visited in sorted
    order.

    Args:
      saved_model_dir: Directory containing the SavedModel to inspect.
    """
    tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
    for tag_set in sorted(tag_sets):
        # Human-readable form for the header only.
        print('\nMetaGraphDef with tag-set: \'' + ', '.join(tag_set) +
              '\' contains the following SignatureDefs:')

        # The lookup helpers receive the tag-set as a single comma-separated
        # string. Joining with ', ' (as the display does) would leave a leading
        # space on every tag after the first, so multi-tag MetaGraphs would not
        # match.
        # NOTE(review): assumes get_signature_def_map / _show_inputs_outputs
        # split the tag-set on ',' (as TF's saved_model_cli helpers do) -- confirm.
        tag_set = ','.join(tag_set)
        signature_def_map = get_signature_def_map(saved_model_dir, tag_set)
        for signature_def_key in sorted(signature_def_map):
            print('\nsignature_def[\'' + signature_def_key + '\']:')
            _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key)
def _show_all(saved_model_dir):
    """Prints tag-set, SignatureDef and Inputs/Outputs information in SavedModel.

    Prints all tag-set, SignatureDef and Inputs/Outputs information stored in
    SavedModel directory. Tag-sets and SignatureDef keys are visited in sorted
    order.

    Args:
      saved_model_dir: Directory containing the SavedModel to inspect.
    """
    tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
    for tag_set in sorted(tag_sets):
        # Human-readable form for the header only.
        print('\nMetaGraphDef with tag-set: \'' + ', '.join(tag_set) +
              '\' contains the following SignatureDefs:')

        # The lookup helpers receive the tag-set as a single comma-separated
        # string. Joining with ', ' (as the display does) would leave a leading
        # space on every tag after the first, so multi-tag MetaGraphs would not
        # match.
        # NOTE(review): assumes get_signature_def_map / _show_inputs_outputs
        # split the tag-set on ',' (as TF's saved_model_cli helpers do) -- confirm.
        tag_set = ','.join(tag_set)
        signature_def_map = get_signature_def_map(saved_model_dir, tag_set)
        for signature_def_key in sorted(signature_def_map):
            print('\nsignature_def[\'' + signature_def_key + '\']:')
            _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key)