def testGetSavedModelTagSets(self):
    export_dir = os.path.join(test.get_temp_dir(), "test_tags")
    model_builder = saved_model_builder.SavedModelBuilder(export_dir)

    # Each stage builds a fresh graph holding a single variable "v" and adds
    # it to the SavedModel:
    #   1. v=42, added together with its weights, predefined TRAINING tag.
    #   2. v=43, graph only (weights not updated), predefined SERVING tag.
    #   3. v=44, graph only (weights not updated), multiple custom tags.
    stages = (
        (42, [tag_constants.TRAINING], True),
        (43, [tag_constants.SERVING], False),
        (44, ["foo", "bar"], False),
    )
    for value, tags, with_variables in stages:
      with self.test_session(graph=ops.Graph()) as sess:
        self._init_and_validate_variable(sess, "v", value)
        if with_variables:
          model_builder.add_meta_graph_and_variables(sess, tags)
        else:
          model_builder.add_meta_graph(tags)

    # Write the SavedModel to disk, then read the tag-sets back; they are
    # expected in the order the MetaGraphs were added.
    model_builder.save()

    found_tags = reader.get_saved_model_tag_sets(export_dir)
    self.assertEqual([["train"], ["serve"], ["foo", "bar"]], found_tags)
Esempio n. 2
0
    def _from_saved_model(saved_model_dir):
        """Load a TF 1.x SavedModel and return it as a frozen ``GraphDef``.

        The tag set of the first MetaGraph is used to load the model. Output
        node names are collected from all of its SignatureDefs. The variables
        are then frozen with ``freeze_graph``, training-only nodes are
        stripped, and the graph is re-imported so it can be exported with
        shapes attached.

        Args:
            saved_model_dir: Directory containing the SavedModel.

        Returns:
            The ``GraphDef`` produced by ``Graph.as_graph_def(add_shapes=True)``.

        Raises:
            NotImplementedError: if the SavedModel exposes no tag sets.
        """
        import os
        import shutil
        import tempfile

        from tensorflow.python.tools import freeze_graph

        # must import here as tf.contrib is only available on TF 1.x
        from tensorflow.contrib.saved_model.python.saved_model import reader

        # Check for an empty result BEFORE indexing: indexing first would raise
        # a bare IndexError instead of the intended NotImplementedError.
        tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
        saved_model_tags = tag_sets[0] if tag_sets else None
        if not saved_model_tags:
            msg = "Unsupported SavedModel directory format: no tag_sets available"
            raise NotImplementedError(msg)

        # Pick the module holding the v1 graph APIs once. From 1.13.1 on they
        # are used via tf.compat.v1; older releases use the top-level module.
        if _get_version(tf.__version__) < _StrictVersion("1.13.1"):
            tf_v1 = tf
        else:
            tf_v1 = tf.compat.v1

        # get model outputs; close the session even if loading fails
        sess = tf_v1.Session()
        try:
            metagraph = tf.saved_model.loader.load(
                sess, saved_model_tags, saved_model_dir
            )
        finally:
            sess.close()
        output_node_names = [
            o.name.split(":")[0]
            for sd in metagraph.signature_def.values()
            for o in sd.outputs.values()
        ]

        # get frozen graph. Write it inside a private temp directory rather than
        # a mktemp() path (racy), and remove the directory afterwards so the
        # intermediate file is not leaked.
        tmp_dir = tempfile.mkdtemp()
        output_graph = os.path.join(tmp_dir, "frozen_graph.pb")
        try:
            tf_v1.reset_default_graph()
            freeze_graph.freeze_graph(
                input_graph=None,
                input_saver=None,
                input_binary=None,
                input_checkpoint=None,
                output_node_names=",".join(output_node_names),
                restore_op_name=None,
                filename_tensor_name=None,
                output_graph=output_graph,
                clear_devices=True,
                initializer_nodes="",
                variable_names_whitelist="",
                variable_names_blacklist="",
                input_meta_graph=None,
                input_saved_model_dir=saved_model_dir,
                saved_model_tags=",".join(saved_model_tags),
            )

            graph_def = tf_v1.GraphDef()
            with open(output_graph, "rb") as f:
                graph_def.ParseFromString(f.read())
        finally:
            shutil.rmtree(tmp_dir, ignore_errors=True)

        graph_def = tf_v1.graph_util.remove_training_nodes(graph_def)
        with tf.Graph().as_default() as graph:
            tf.graph_util.import_graph_def(graph_def, name="")
        return graph.as_graph_def(add_shapes=True)
Esempio n. 3
0
 def _get_tag_set(self):
     """Return the tag set of saved model, multiple metagraphs are not supported"""
     try:
         from tensorflow.contrib.saved_model.python.saved_model import reader
     except ImportError:
         raise ImportError(
             "InputConfiguration: Unable to import saved_model.reader which is "
             "required to get tag set from saved model.")
     tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
     return tag_sets[0]
Esempio n. 4
0
def _show_tag_sets(saved_model_dir):
    """Print every tag-set stored in a SavedModel directory.

    A header line is printed first. Then each tag-set is printed on its own
    line, in sorted order, with its tags sorted and joined by ", ".

    Args:
      saved_model_dir: Directory containing the SavedModel to inspect.
    """
    all_tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
    print('The given SavedModel contains the following tag-sets:')
    rendered_lines = (', '.join(sorted(tags)) for tags in sorted(all_tag_sets))
    for line in rendered_lines:
        print(line)
Esempio n. 5
0
    def __get_tag_set(self) -> str:
        try:
            from tensorflow.contrib.saved_model.python.saved_model import reader
        except ImportError:
            raise ImportError(
                "InputConfiguration: Unable to import saved_model.reader which is "
                "required to get tag set from saved model.")

        tag_sets = reader.get_saved_model_tag_sets(self.model_path.as_posix())
        return tag_sets[0]
Esempio n. 6
0
 def _get_tag_set(self):
     """Return the tag set of saved model, multiple metagraphs are not supported"""
     try:
         from tensorflow.contrib.saved_model.python.saved_model import reader
     except ImportError:
         raise ImportError(
             "InputConfiguration: Unable to import saved_model.reader which is "
             "required to get tag set from saved model.")
     tag_sets = reader.get_saved_model_tag_sets(self._model_dir)
     return tag_sets[0]
Esempio n. 7
0
def _show_tag_sets(saved_model_dir):
  """Print all MetaGraph tag-sets found in a SavedModel directory.

  After a header line, the tag-sets are listed one per line in sorted order.
  Within each line the tags are sorted and separated by ", ".

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect.
  """
  header = 'The given SavedModel contains the following tag-sets:'
  stored_tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
  print(header)
  for tags in sorted(stored_tag_sets):
    ordered_tags = sorted(tags)
    print(', '.join(ordered_tags))
Esempio n. 8
0
def _show_all(saved_model_dir):
    """Print tag-set, SignatureDef and Inputs/Outputs information in SavedModel.

    For every MetaGraphDef in the SavedModel directory, taken in sorted
    tag-set order, this prints a header naming its tag-set. It then prints
    each of its SignatureDefs, in sorted key order, with their inputs and
    outputs.

    Args:
      saved_model_dir: Directory containing the SavedModel to inspect.
    """
    tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
    for tags in sorted(tag_sets):
        # ', ' is only for the human-readable header.
        print('\nMetaGraphDef with tag-set: \'' + ', '.join(tags) +
              '\' contains the following SignatureDefs:')

        # The lookup helpers receive the tag-set as a comma-separated string.
        # Joining with ', ' would yield tags like ' gpu' for multi-tag
        # MetaGraphDefs, which then fail to match.
        # NOTE(review): assumes get_signature_def_map/_show_inputs_outputs split
        # on ',' without stripping whitespace, as TF's saved_model_cli does;
        # confirm.
        tag_set = ','.join(tags)
        signature_def_map = get_signature_def_map(saved_model_dir, tag_set)
        for signature_def_key in sorted(signature_def_map.keys()):
            print('\nsignature_def[\'' + signature_def_key + '\']:')
            _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key)
Esempio n. 9
0
def _show_all(saved_model_dir):
  """Prints tag-set, SignatureDef and Inputs/Outputs information in SavedModel.

  For every MetaGraphDef in the SavedModel directory, taken in sorted tag-set
  order, this prints a header naming its tag-set. It then prints each of its
  SignatureDefs, in sorted key order, with their inputs and outputs.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect.
  """
  tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
  for tags in sorted(tag_sets):
    # ', ' is only for the human-readable header.
    print('\nMetaGraphDef with tag-set: \'' + ', '.join(tags) +
          '\' contains the following SignatureDefs:')

    # The lookup helpers receive the tag-set as a comma-separated string.
    # Joining with ', ' would yield tags like ' gpu' for multi-tag
    # MetaGraphDefs, which then fail to match.
    # NOTE(review): assumes get_signature_def_map/_show_inputs_outputs split on
    # ',' without stripping whitespace, as TF's saved_model_cli does; confirm.
    tag_set = ','.join(tags)
    signature_def_map = get_signature_def_map(saved_model_dir, tag_set)
    for signature_def_key in sorted(signature_def_map.keys()):
      print('\nsignature_def[\'' + signature_def_key + '\']:')
      _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key)