def testMaybeSessionBundleDir(self):
    """Checks maybe_session_bundle_dir on bundle, SavedModel and bogus paths."""
    # The SessionBundle fixture directory must be recognized.
    bundle_dir = test.test_src_dir_path(SESSION_BUNDLE_PATH)
    self.assertTrue(session_bundle.maybe_session_bundle_dir(bundle_dir))

    # The SavedModel fixture directory must not be mistaken for a SessionBundle.
    saved_model_dir = test.test_src_dir_path(SAVED_MODEL_PATH)
    self.assertFalse(session_bundle.maybe_session_bundle_dir(saved_model_dir))

    # An arbitrary non-export path must be rejected.
    self.assertFalse(session_bundle.maybe_session_bundle_dir("complete_garbage"))
def testMaybeSessionBundleDir(self):
    """Checks maybe_session_bundle_dir on bundle, SavedModel and bogus paths."""
    # The SessionBundle fixture directory must be recognized.
    bundle_dir = tf.test.test_src_dir_path(SESSION_BUNDLE_PATH)
    self.assertTrue(session_bundle.maybe_session_bundle_dir(bundle_dir))

    # The SavedModel fixture directory must not be mistaken for a SessionBundle.
    saved_model_dir = tf.test.test_src_dir_path(SAVED_MODEL_PATH)
    self.assertFalse(session_bundle.maybe_session_bundle_dir(saved_model_dir))

    # An arbitrary non-export path must be rejected.
    self.assertFalse(session_bundle.maybe_session_bundle_dir("complete_garbage"))
def load_session_bundle_or_saved_model_bundle_from_path(export_dir,
                                                        tags=None,
                                                        target="",
                                                        config=None):
    """Load a SavedModel or a SessionBundle from the given path.

    A SavedModel export is tried first. If none is found, the directory is
    treated as a legacy SessionBundle export. The function reads input from
    export_dir, constructs the graph data into the session's graph and
    restores the parameters for the session created.

    Args:
      export_dir: the directory that contains files exported by exporter.
      tags: Set of string tags to identify the required MetaGraphDef when the
        model is saved as SavedModel. These should correspond to the tags used
        when saving the variables using the SavedModel `save()` API. This value
        is only passed to `loader.load` (the SavedModel branch). NOTE(review):
        the default None is forwarded unchanged there, so it presumably must be
        supplied for SavedModel exports -- TODO confirm.
      target: The execution engine to connect to. See target in tf.Session()
      config: A ConfigProto proto with configuration options. See config in
        tf.Session()

    Returns:
      A `(session, meta_graph)` tuple where:
        session: a tensorflow session created from the variable files.
        meta_graph: a meta graph proto saved in the exporter directory.

    Raises:
      RuntimeError: if neither a SavedModel nor a SessionBundle is found at
        export_dir. Errors raised while loading either format are propagated.
    """
    if loader.maybe_saved_model_directory(export_dir):
        sess = session.Session(target, graph=None, config=config)
        try:
            metagraph_def = loader.load(sess, tags, export_dir)
        except Exception:
            # Loading failed, so the caller never receives the session; close it
            # here instead of leaking its graph and runtime resources.
            sess.close()
            raise
        return sess, metagraph_def

    if session_bundle.maybe_session_bundle_dir(export_dir):
        return _load_saved_model_from_session_bundle_path(export_dir, target,
                                                          config)

    raise RuntimeError("SessionBundle or SavedModelBundle not found at "
                       "specified export location: %s" % export_dir)