def testMaybeSessionBundleDir(self):
     """Checks maybe_session_bundle_dir on bundle, SavedModel and junk paths."""
     # The SessionBundle test export must be reported as a session bundle.
     bundle_dir = test.test_src_dir_path(SESSION_BUNDLE_PATH)
     self.assertTrue(session_bundle.maybe_session_bundle_dir(bundle_dir))
     # A SavedModel export and an arbitrary garbage path must both be rejected.
     rejected_dirs = (test.test_src_dir_path(SAVED_MODEL_PATH),
                      "complete_garbage")
     for candidate in rejected_dirs:
         self.assertFalse(session_bundle.maybe_session_bundle_dir(candidate))
 def testMaybeSessionBundleDir(self):
     """Checks maybe_session_bundle_dir on bundle, SavedModel and junk paths."""
     # Positive case: the SessionBundle test export is detected.
     bundle_dir = tf.test.test_src_dir_path(SESSION_BUNDLE_PATH)
     self.assertTrue(session_bundle.maybe_session_bundle_dir(bundle_dir))
     # Negative cases: a SavedModel export and a garbage path are not bundles.
     rejected_dirs = (tf.test.test_src_dir_path(SAVED_MODEL_PATH),
                      "complete_garbage")
     for candidate in rejected_dirs:
         self.assertFalse(session_bundle.maybe_session_bundle_dir(candidate))
Example #3
0
def load_session_bundle_or_saved_model_bundle_from_path(export_dir,
                                                        tags=None,
                                                        target="",
                                                        config=None):
  """Load a SavedModel or a legacy SessionBundle from the given path.

  The SavedModel format is tried first; if export_dir is not recognized as a
  SavedModel directory, it is checked for a SessionBundle export instead.
  The function reads input from the export_dir, constructs the graph data to
  the default graph and restores the parameters for the session created.

  Args:
    export_dir: the directory that contains files exported by exporter.
    tags: Set of string tags to identify the required MetaGraphDef when model is
          saved as SavedModel. These should correspond to the tags used when
          saving the variables using the SavedModel `save()` API. Not used
          when export_dir holds a SessionBundle.
    target: The execution engine to connect to. See target in tf.Session()
    config: A ConfigProto proto with configuration options. See config in
            tf.Session()

  Returns:
    A `(session, meta_graph)` tuple where
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if export_dir holds neither a SavedModel nor a
      SessionBundle, or if the required files are missing or contain
      unrecognizable fields, i.e. the exported model is invalid.
  """
  if loader.maybe_saved_model_directory(export_dir):
    sess = session.Session(target, graph=None, config=config)
    try:
      metagraph_def = loader.load(sess, tags, export_dir)
    except BaseException:
      # The session was created by this function and is not returned on
      # failure, so close it here rather than leaking it; then re-raise the
      # original load error unchanged.
      sess.close()
      raise
    return sess, metagraph_def

  if session_bundle.maybe_session_bundle_dir(export_dir):
    # Legacy SessionBundle export; `tags` only applies to SavedModel.
    sess, metagraph_def = _load_saved_model_from_session_bundle_path(export_dir,
                                                                     target,
                                                                     config)
    return sess, metagraph_def

  raise RuntimeError("SessionBundle or SavedModelBundle not found at "
                     "specified export location: %s" % export_dir)
Example #4
0
def load_session_bundle_or_saved_model_bundle_from_path(export_dir,
                                                        tags=None,
                                                        target="",
                                                        config=None):
  """Load a SavedModel or a legacy SessionBundle from the given path.

  The SavedModel format is tried first; if export_dir is not recognized as a
  SavedModel directory, it is checked for a SessionBundle export instead.
  The function reads input from the export_dir, constructs the graph data to
  the default graph and restores the parameters for the session created.

  Args:
    export_dir: the directory that contains files exported by exporter.
    tags: Set of string tags to identify the required MetaGraphDef when model is
          saved as SavedModel. These should correspond to the tags used when
          saving the variables using the SavedModel `save()` API. Not used
          when export_dir holds a SessionBundle.
    target: The execution engine to connect to. See target in tf.Session()
    config: A ConfigProto proto with configuration options. See config in
            tf.Session()

  Returns:
    A `(session, meta_graph)` tuple where
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if export_dir holds neither a SavedModel nor a
      SessionBundle, or if the required files are missing or contain
      unrecognizable fields, i.e. the exported model is invalid.
  """
  if loader.maybe_saved_model_directory(export_dir):
    sess = session.Session(target, graph=None, config=config)
    try:
      metagraph_def = loader.load(sess, tags, export_dir)
    except BaseException:
      # The session was created by this function and is not returned on
      # failure, so close it here rather than leaking it; then re-raise the
      # original load error unchanged.
      sess.close()
      raise
    return sess, metagraph_def

  if session_bundle.maybe_session_bundle_dir(export_dir):
    # Legacy SessionBundle export; `tags` only applies to SavedModel.
    sess, metagraph_def = _load_saved_model_from_session_bundle_path(export_dir,
                                                                     target,
                                                                     config)
    return sess, metagraph_def

  raise RuntimeError("SessionBundle or SavedModelBundle not found at "
                     "specified export location: %s" % export_dir)