def testMaybeSavedModelDir(self):
  """Checks maybe_saved_model_directory on two negative paths and one positive path."""
  # A source directory the test expects NOT to be detected as a SavedModel.
  parent_dir = test.test_src_dir_path("/python/saved_model")
  self.assertFalse(loader.maybe_saved_model_directory(parent_dir))

  # The SAVED_MODEL_PATH fixture is expected to be detected.
  fixture_dir = test.test_src_dir_path(SAVED_MODEL_PATH)
  self.assertTrue(loader.maybe_saved_model_directory(fixture_dir))

  # A bogus relative path must not be detected either.
  self.assertFalse(loader.maybe_saved_model_directory("complete_garbage"))
def load_model(model_path, config=None):
  """Loads the SavedModel at the specified path with the SERVING tag.

  Args:
    model_path: the path to a SavedModel directory. Paths that
      loader.maybe_saved_model_directory does not recognise (e.g. legacy
      session_bundle exports) are rejected with a PredictionError.
    config: tf.ConfigProto containing session configuration options.

  Returns:
    A pair of (Session, SignatureDef) objects.

  Raises:
    PredictionError: if the model could not be loaded.
  """
  if not loader.maybe_saved_model_directory(model_path):
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "Cloud ML only supports TF 1.0 or above and models "
                          "saved in SavedModel format.")

  session = None
  try:
    session = tf_session.Session(target="", graph=None, config=config)
    meta_graph = loader.load(session, tags=[tag_constants.SERVING],
                             export_dir=model_path)
  except Exception:  # pylint: disable=broad-except
    # The session is never handed to the caller on failure, so release it
    # here instead of leaking it.
    if session is not None:
      try:
        session.close()
      except Exception:  # pylint: disable=broad-except
        pass  # Best effort: report the load failure, not a cleanup error.
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "Failed to load the model due to bad model data.")

  signature = _get_signature_from_meta_graph(session.graph, meta_graph)
  return session, signature
def load_tf_model(model_path, tags=(tag_constants.SERVING,), config=None):
  """Loads the SavedModel at the specified path and its valid signatures.

  Args:
    model_path: the path to a SavedModel directory. Paths that
      loader.maybe_saved_model_directory does not recognise (e.g. legacy
      session_bundle exports) are rejected with a PredictionError.
    tags: the tags that determines the model to load.
    config: tf.ConfigProto containing session configuration options.

  Returns:
    A pair of (Session, map<string, SignatureDef>) objects. Signatures whose
    input/output dtypes could not be updated (ValueError from _update_dtypes)
    are logged and removed from the returned map.

  Raises:
    PredictionError: if the model could not be loaded, or the loaded
      MetaGraph has no signature_def at all.
  """

  def _close_quietly(sess):
    """Best-effort close of a session that will not be returned to callers."""
    try:
      sess.close()
    except Exception:  # pylint: disable=broad-except
      pass  # Report the original failure, not a cleanup error.

  if not loader.maybe_saved_model_directory(model_path):
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "Cloud ML only supports TF 1.0 or above and models "
                          "saved in SavedModel format.")

  session = None
  try:
    # NOTE(review): this message mentions tensorflow.contrib, but only
    # `tensorflow` and `ops.Graph` are imported below.
    logging.info("Importing tensorflow.contrib in load_tf_model")
    # pylint: disable=redefined-outer-name,unused-variable,g-import-not-at-top
    import tensorflow as tf
    from tensorflow.python.framework.ops import Graph
    # pylint: enable=redefined-outer-name,unused-variable,g-import-not-at-top
    if tf.__version__.startswith("1.0"):
      # TF 1.0.x: load into the default graph (graph=None).
      session = tf_session.Session(target="", graph=None, config=config)
    else:
      # Later versions: load into a fresh Graph rather than the default one.
      session = tf_session.Session(target="", graph=Graph(), config=config)
    meta_graph = loader.load(session, tags=list(tags), export_dir=model_path)
  except Exception as e:  # pylint: disable=broad-except
    # Don't leak a session that was created before the failure.
    if session is not None:
      _close_quietly(session)
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "Failed to load the model due to bad model data."
                          " tags: %s\n%s" % (list(tags), str(e)))

  if not meta_graph.signature_def:
    _close_quietly(session)
    raise PredictionError(PredictionError.FAILED_TO_LOAD_MODEL,
                          "MetaGraph must have at least one signature_def.")

  # Remove invalid signatures from the signature map. Names are collected
  # first so the map is not mutated while it is being iterated.
  invalid_signatures = []
  for signature_name in meta_graph.signature_def:
    try:
      signature = meta_graph.signature_def[signature_name]
      _update_dtypes(session.graph, signature.inputs)
      _update_dtypes(session.graph, signature.outputs)
    except ValueError as e:
      logging.warning("Error updating signature %s: %s", signature_name,
                      str(e))
      invalid_signatures.append(signature_name)
  for signature_name in invalid_signatures:
    del meta_graph.signature_def[signature_name]
  return session, meta_graph.signature_def
def load_model(
    model_path,
    tags=(tag_constants.SERVING, ),
    signature_name=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
    config=None):
  """Loads the SavedModel at the specified path and selects one signature.

  Args:
    model_path: the path to a SavedModel directory. Paths that
      loader.maybe_saved_model_directory does not recognise (e.g. legacy
      session_bundle exports) are rejected with a PredictionError.
    tags: the tags that determines the model to load.
    signature_name: the string used as the key to signature map to locate
      the serving signature.
    config: tf.ConfigProto containing session configuration options.

  Returns:
    A pair of (Session, SignatureDef) objects.

  Raises:
    PredictionError: if the model could not be loaded.
  """
  if not loader.maybe_saved_model_directory(model_path):
    raise PredictionError(
        PredictionError.FAILED_TO_LOAD_MODEL,
        "Cloud ML only supports TF 1.0 or above and models "
        "saved in SavedModel format.")

  session = None
  try:
    session = tf_session.Session(target="", graph=None, config=config)
    meta_graph = loader.load(session, tags=list(tags), export_dir=model_path)
  except Exception as e:  # pylint: disable=broad-except
    # The session is never handed to the caller on failure; release it.
    if session is not None:
      try:
        session.close()
      except Exception:  # pylint: disable=broad-except
        pass  # Best effort: report the load failure, not a cleanup error.
    # Pass `tags` as an explicit format argument: `"%s" % tags` treats a
    # multi-element tuple as several arguments and raises TypeError, which
    # would hide this PredictionError from callers.
    raise PredictionError(
        PredictionError.FAILED_TO_LOAD_MODEL,
        "Failed to load the model due to bad model data."
        " tags: %s\n%s" % (list(tags), str(e)))

  signature = _get_signature_from_meta_graph(session.graph, meta_graph,
                                             signature_name)
  return session, signature
def load_session_bundle_or_saved_model_bundle_from_path(export_dir,
                                                        tags=None,
                                                        target="",
                                                        config=None):
  """Loads a SavedModel or a legacy SessionBundle from the given path.

  If export_dir is recognised as a SavedModel directory, a new session is
  created (graph=None, i.e. the default graph) and the MetaGraphDef matching
  `tags` is loaded into it. Otherwise, if it is recognised as a SessionBundle
  directory, loading is delegated to
  _load_saved_model_from_session_bundle_path.

  Args:
    export_dir: the directory that contains files exported by exporter.
    tags: Set of string tags to identify the required MetaGraphDef when model
      is saved as SavedModel. These should correspond to the tags used when
      saving the variables using the SavedModel `save()` API.
    target: The execution engine to connect to. See target in tf.Session()
    config: A ConfigProto proto with configuration options. See config in
      tf.Session()

  Returns:
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory.

  Raises:
    RuntimeError: if export_dir is neither a SavedModel nor a SessionBundle
      directory. Errors raised by loader.load (or by the SessionBundle
      helper) propagate unchanged.
  """
  if loader.maybe_saved_model_directory(export_dir):
    sess = session.Session(target, graph=None, config=config)
    try:
      metagraph_def = loader.load(sess, tags, export_dir)
    except Exception:  # pylint: disable=broad-except
      # The caller never receives the session on failure, so close it
      # before re-raising the original error.
      sess.close()
      raise
    return sess, metagraph_def

  if session_bundle.maybe_session_bundle_dir(export_dir):
    sess, metagraph_def = _load_saved_model_from_session_bundle_path(
        export_dir, target, config)
    return sess, metagraph_def

  raise RuntimeError("SessionBundle or SavedModelBundle not found at "
                     "specified export location: %s" % export_dir)