Beispiel #1
0
 def test_save_graph_model_default_session(self):
   """save_graph_model writes saved_model.pb for a session on the default graph."""
   x = tf.placeholder(shape=[None, 10], dtype=tf.float32, name='inp')
   weights = tf.constant(1., shape=(10, 2), name='weights')
   model_path = os.path.join(tf.test.get_temp_dir(), 'default')
   # A session created without an explicit graph binds to the default graph,
   # which is where `x` and `weights` were built. Register close() as a
   # cleanup rather than using `with sess:` so the call below sees exactly the
   # same default-session/default-graph state as before, without leaking the
   # session's resources.
   sess = tf.Session()
   self.addCleanup(sess.close)
   utils.save_graph_model(
       sess, model_path, {'x': x}, {'w': weights}, ['tag'])
   self.assertTrue(os.path.isfile(os.path.join(model_path, 'saved_model.pb')))
Beispiel #2
0
    def save_model_with_metadata(self, file_path):
        """Save the model together with its generated metadata.

        A serving signature that is still unset (falsy) when this is called is
        built here from `self._inputs` / `self._outputs` against the session's
        graph, and then cached on the instance.

        Args:
          file_path: Destination for the model and the metadata. It can be a
            GCS bucket or a local folder. The folder needs to be empty.

        Returns:
          `file_path`, the location the model and the metadata were written to.
        """
        metadata = self.get_metadata()

        # Lazily fill in whichever serving signatures were not supplied.
        if not self._serving_inputs:
            self._serving_inputs = self._build_input_signature(
                self._inputs, self._session.graph)
        if not self._serving_outputs:
            self._serving_outputs = self._build_output_signature(
                self._outputs, self._session.graph)

        utils.save_graph_model(
            self._session,
            file_path,
            self._serving_inputs,
            self._serving_outputs,
            self._tags,
            **self._saved_model_args
        )
        common_utils.write_metadata_to_file(metadata, file_path)

        return file_path
Beispiel #3
0
 def test_save_graph_model_kwargs(self):
   """save_graph_model accepts extra keyword args and still writes the model.

   `main_op` and `strip_default_attrs` are passed through the keyword
   arguments of save_graph_model; the test checks saved_model.pb is produced.
   """
   x = tf.placeholder(shape=[None, 10], dtype=tf.float32, name='inp')
   weights = tf.constant(1., shape=(10, 2), name='weights')
   model_path = os.path.join(tf.test.get_temp_dir(), 'kwargs')
   # Close the session at test teardown instead of leaking it; addCleanup keeps
   # the default-graph state during the call identical to passing it inline.
   sess = tf.Session()
   self.addCleanup(sess.close)
   utils.save_graph_model(
       sess,
       model_path, {'x': x}, {'w': weights}, ['tag'],
       main_op=tf.tables_initializer(),
       strip_default_attrs=False)
   self.assertTrue(os.path.isfile(os.path.join(model_path, 'saved_model.pb')))
Beispiel #4
0
 def test_save_graph_model_explicit_session(self):
   """save_graph_model handles a session whose graph is not the default graph.

   The exported model is then loaded into a fresh graph to check that both the
   placeholder and the constant made it into the SavedModel.
   """
   sess = tf.Session(graph=tf.Graph())
   # Release the session at teardown. addCleanup is used instead of
   # `with sess:` because entering the session would also make it (and its
   # graph) the default, which could mask a save_graph_model that fails to
   # handle a non-default graph itself.
   self.addCleanup(sess.close)
   with sess.graph.as_default():
     x = tf.placeholder(shape=[None, 10], dtype=tf.float32, name='inp')
     weights = tf.constant(1., shape=(10, 2), name='weights')
   model_path = os.path.join(tf.test.get_temp_dir(), 'explicit')
   utils.save_graph_model(sess, model_path, {'x': x}, {'w': weights}, ['tag'])
   self.assertTrue(os.path.isfile(os.path.join(model_path, 'saved_model.pb')))
   tf.reset_default_graph()
   loading_session = tf.Session(graph=tf.Graph())
   self.addCleanup(loading_session.close)
   with loading_session.graph.as_default():
     tf.saved_model.loader.load(loading_session, ['tag'], model_path)
     # Serialize the loaded graph once and reuse the node names for both checks.
     loaded_names = [n.name for n in loading_session.graph.as_graph_def().node]
   self.assertIn(x.op.name, loaded_names)
   self.assertIn(weights.op.name, loaded_names)